forked from OSchip/llvm-project
[mlir] Add a generic while/do-while loop to the SCF dialect
The new construct represents a generic loop with two regions: one executed before the loop condition is verifier and another after that. This construct can be used to express both a "while" loop and a "do-while" loop, depending on where the main payload is located. It is intended as an intermediate abstraction for lowering, which will be added later. This form is relatively easy to target from higher-level abstractions and supports transformations such as loop rotation and LICM. Differential Revision: https://reviews.llvm.org/D90255
This commit is contained in:
parent
3bec07f91f
commit
79716559b5
|
@ -36,6 +36,25 @@ class SCF_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def ConditionOp : SCF_Op<"condition",
|
||||||
|
[HasParent<"WhileOp">, NoSideEffect, Terminator]> {
|
||||||
|
let summary = "loop continuation condition";
|
||||||
|
let description = [{
|
||||||
|
This operation accepts the continuation (i.e., inverse of exit) condition
|
||||||
|
of the `scf.while` construct. If its first argument is true, the "after"
|
||||||
|
region of `scf.while` is executed, with the remaining arguments forwarded
|
||||||
|
to the entry block of the region. Otherwise, the loop terminates.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins I1:$condition, Variadic<AnyType>:$args);
|
||||||
|
|
||||||
|
let assemblyFormat =
|
||||||
|
[{ `(` $condition `)` attr-dict ($args^ `:` type($args))? }];
|
||||||
|
|
||||||
|
// Override the default verifier, everything is checked by traits.
|
||||||
|
let verifier = ?;
|
||||||
|
}
|
||||||
|
|
||||||
def ForOp : SCF_Op<"for",
|
def ForOp : SCF_Op<"for",
|
||||||
[DeclareOpInterfaceMethods<LoopLikeOpInterface>,
|
[DeclareOpInterfaceMethods<LoopLikeOpInterface>,
|
||||||
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
|
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
|
||||||
|
@ -413,8 +432,135 @@ def ReduceReturnOp :
|
||||||
let assemblyFormat = "$result attr-dict `:` type($result)";
|
let assemblyFormat = "$result attr-dict `:` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def WhileOp : SCF_Op<"while",
|
||||||
|
[DeclareOpInterfaceMethods<RegionBranchOpInterface>,
|
||||||
|
RecursiveSideEffects]> {
|
||||||
|
let summary = "a generic 'while' loop";
|
||||||
|
let description = [{
|
||||||
|
This operation represents a generic "while"/"do-while" loop that keeps
|
||||||
|
iterating as long as a condition is satisfied. There is no restriction on
|
||||||
|
the complexity of the condition. It consists of two regions (with single
|
||||||
|
block each): "before" region and "after" region. The names of regions
|
||||||
|
indicates whether they execute before or after the condition check.
|
||||||
|
Therefore, if the main loop payload is located in the "before" region, the
|
||||||
|
operation is a "do-while" loop. Otherwise, it is a "while" loop.
|
||||||
|
|
||||||
|
The "before" region terminates with a special operation, `scf.condition`,
|
||||||
|
that accepts as its first operand an `i1` value indicating whether to
|
||||||
|
proceed to the "after" region (value is `true`) or not. The two regions
|
||||||
|
communicate by means of region arguments. Initially, the "before" region
|
||||||
|
accepts as arguments the operands of the `scf.while` operation and uses them
|
||||||
|
to evaluate the condition. It forwards the trailing, non-condition operands
|
||||||
|
of the `scf.condition` terminator either to the "after" region if the
|
||||||
|
control flow is transferred there or to results of the `scf.while` operation
|
||||||
|
otherwise. The "after" region takes as arguments the values produced by the
|
||||||
|
"before" region and uses `scf.yield` to supply new arguments for the "after"
|
||||||
|
region, into which it transfers the control flow unconditionally.
|
||||||
|
|
||||||
|
A simple "while" loop can be represented as follows.
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
%res = scf.while (%arg1 = %init1) : (f32) -> f32 {
|
||||||
|
/* "Before" region.
|
||||||
|
* In a "while" loop, this region computes the condition. */
|
||||||
|
%condition = call @evaluate_condition(%arg1) : (f32) -> i1
|
||||||
|
|
||||||
|
/* Forward the argument (as result or "after" region argument). */
|
||||||
|
scf.condition(%condition) %arg1 : f32
|
||||||
|
|
||||||
|
} do {
|
||||||
|
^bb0(%arg2: f32):
|
||||||
|
/* "After region.
|
||||||
|
* In a "while" loop, this region is the loop body. */
|
||||||
|
%next = call @payload(%arg2) : (f32) -> f32
|
||||||
|
|
||||||
|
/* Forward the new value to the "before" region.
|
||||||
|
* The operand types must match the types of the `scf.while` operands. */
|
||||||
|
scf.yield %next : f32
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
A simple "do-while" loop can be represented by reducing the "after" block
|
||||||
|
to a simple forwarder.
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
%res = scf.while (%arg1 = %init1) : (f32) -> f32 {
|
||||||
|
/* "Before" region.
|
||||||
|
* In a "do-while" loop, this region contains the loop body. */
|
||||||
|
%next = call @payload(%arg1) : (f32) -> f32
|
||||||
|
|
||||||
|
/* And also evalutes the condition. */
|
||||||
|
%condition = call @evaluate_condition(%arg1) : (f32) -> i1
|
||||||
|
|
||||||
|
/* Loop through the "after" region. */
|
||||||
|
scf.condition(%condition) %next : f32
|
||||||
|
|
||||||
|
} do {
|
||||||
|
^bb0(%arg2: f32):
|
||||||
|
/* "After" region.
|
||||||
|
* Forwards the values back to "before" region unmodified. */
|
||||||
|
scf.yield %arg2 : f32
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that the types of region arguments need not to match with each other.
|
||||||
|
The op expects the operand types to match with argument types of the
|
||||||
|
"before" region"; the result types to match with the trailing operand types
|
||||||
|
of the terminator of the "before" region, and with the argument types of the
|
||||||
|
"after" region. The following scheme can be used to share the results of
|
||||||
|
some operations executed in the "before" region with the "after" region,
|
||||||
|
avoiding the need to recompute them.
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
%res = scf.while (%arg1 = %init1) : (f32) -> i64 {
|
||||||
|
/* One can perform some computations, e.g., necessary to evaluate the
|
||||||
|
* condition, in the "before" region and forward their results to the
|
||||||
|
* "after" region. */
|
||||||
|
%shared = call @shared_compute(%arg1) : (f32) -> i64
|
||||||
|
|
||||||
|
/* Evalute the condition. */
|
||||||
|
%condition = call @evaluate_condition(%arg1, %shared) : (f32, i64) -> i1
|
||||||
|
|
||||||
|
/* Forward the result of the shared computation to the "after" region.
|
||||||
|
* The types must match the arguments of the "after" region as well as
|
||||||
|
* those of the `scf.while` results. */
|
||||||
|
scf.condition(%condition) %shared : i64
|
||||||
|
|
||||||
|
} do {
|
||||||
|
^bb0(%arg2: i64) {
|
||||||
|
/* Use the partial result to compute the rest of the payload in the
|
||||||
|
* "after" region. */
|
||||||
|
%res = call @payload(%arg2) : (i64) -> f32
|
||||||
|
|
||||||
|
/* Forward the new value to the "before" region.
|
||||||
|
* The operand types must match the types of the `scf.while` operands. */
|
||||||
|
scf.yield %res : f32
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The custom syntax for this operation is as follows.
|
||||||
|
|
||||||
|
```
|
||||||
|
op ::= `scf.while` assignments `:` function-type region `do` region
|
||||||
|
`attributes` attribute-dict
|
||||||
|
initializer ::= /* empty */ | `(` assignment-list `)`
|
||||||
|
assignment-list ::= assignment | assignment `,` assignment-list
|
||||||
|
assignment ::= ssa-value `=` ssa-value
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins Variadic<AnyType>:$inits);
|
||||||
|
let results = (outs Variadic<AnyType>:$results);
|
||||||
|
let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
OperandRange getSuccessorEntryOperands(unsigned index);
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
|
def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
|
||||||
ParentOneOf<["IfOp, ForOp", "ParallelOp"]>]> {
|
ParentOneOf<["IfOp, ForOp", "ParallelOp",
|
||||||
|
"WhileOp"]>]> {
|
||||||
let summary = "loop yield and termination operation";
|
let summary = "loop yield and termination operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"scf.yield" yields an SSA value from the SCF dialect op region and
|
"scf.yield" yields an SSA value from the SCF dialect op region and
|
||||||
|
@ -434,4 +580,5 @@ def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
|
||||||
// needed.
|
// needed.
|
||||||
let verifier = ?;
|
let verifier = ?;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MLIR_DIALECT_SCF_SCFOPS
|
#endif // MLIR_DIALECT_SCF_SCFOPS
|
||||||
|
|
|
@ -755,10 +755,17 @@ public:
|
||||||
parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
|
parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
|
||||||
|
|
||||||
/// Parse a list of assignments of the form
|
/// Parse a list of assignments of the form
|
||||||
/// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...).
|
/// (%x1 = %y1, %x2 = %y2, ...)
|
||||||
/// The list must contain at least one entry
|
ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
|
||||||
virtual ParseResult
|
SmallVectorImpl<OperandType> &rhs) {
|
||||||
parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
|
OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
|
||||||
|
if (!result.hasValue())
|
||||||
|
return emitError(getCurrentLocation(), "expected '('");
|
||||||
|
return result.getValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual OptionalParseResult
|
||||||
|
parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
|
||||||
SmallVectorImpl<OperandType> &rhs) = 0;
|
SmallVectorImpl<OperandType> &rhs) = 0;
|
||||||
|
|
||||||
/// Parse a keyword followed by a type.
|
/// Parse a keyword followed by a type.
|
||||||
|
|
|
@ -140,26 +140,37 @@ static LogicalResult verify(ForOp op) {
|
||||||
return RegionBranchOpInterface::verifyTypes(op);
|
return RegionBranchOpInterface::verifyTypes(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, ForOp op) {
|
/// Prints the initialization list in the form of
|
||||||
bool printBlockTerminators = false;
|
/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
|
||||||
p << op.getOperationName() << " " << op.getInductionVar() << " = "
|
/// where 'inner' values are assumed to be region arguments and 'outer' values
|
||||||
<< op.lowerBound() << " to " << op.upperBound() << " step " << op.step();
|
/// are regular SSA values.
|
||||||
|
static void printInitializationList(OpAsmPrinter &p,
|
||||||
|
Block::BlockArgListType blocksArgs,
|
||||||
|
ValueRange initializers,
|
||||||
|
StringRef prefix = "") {
|
||||||
|
assert(blocksArgs.size() == initializers.size() &&
|
||||||
|
"expected same length of arguments and initializers");
|
||||||
|
if (initializers.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
if (op.hasIterOperands()) {
|
p << prefix << '(';
|
||||||
p << " iter_args(";
|
llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
|
||||||
auto regionArgs = op.getRegionIterArgs();
|
|
||||||
auto operands = op.getIterOperands();
|
|
||||||
|
|
||||||
llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
|
|
||||||
p << std::get<0>(it) << " = " << std::get<1>(it);
|
p << std::get<0>(it) << " = " << std::get<1>(it);
|
||||||
});
|
});
|
||||||
p << ")";
|
p << ")";
|
||||||
p << " -> (" << op.getResultTypes() << ")";
|
}
|
||||||
printBlockTerminators = true;
|
|
||||||
}
|
static void print(OpAsmPrinter &p, ForOp op) {
|
||||||
|
p << op.getOperationName() << " " << op.getInductionVar() << " = "
|
||||||
|
<< op.lowerBound() << " to " << op.upperBound() << " step " << op.step();
|
||||||
|
|
||||||
|
printInitializationList(p, op.getRegionIterArgs(), op.getIterOperands(),
|
||||||
|
" iter_args");
|
||||||
|
if (!op.getIterOperands().empty())
|
||||||
|
p << " -> (" << op.getIterOperands().getTypes() << ')';
|
||||||
p.printRegion(op.region(),
|
p.printRegion(op.region(),
|
||||||
/*printEntryBlockArgs=*/false,
|
/*printEntryBlockArgs=*/false,
|
||||||
/*printBlockTerminators=*/printBlockTerminators);
|
/*printBlockTerminators=*/op.hasIterOperands());
|
||||||
p.printOptionalAttrDict(op.getAttrs());
|
p.printOptionalAttrDict(op.getAttrs());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -933,6 +944,158 @@ static LogicalResult verify(ReduceReturnOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// WhileOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) {
|
||||||
|
assert(index == 0 &&
|
||||||
|
"WhileOp is expected to branch only to the first region");
|
||||||
|
|
||||||
|
return inits();
|
||||||
|
}
|
||||||
|
|
||||||
|
void WhileOp::getSuccessorRegions(Optional<unsigned> index,
|
||||||
|
ArrayRef<Attribute> operands,
|
||||||
|
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||||
|
(void)operands;
|
||||||
|
|
||||||
|
if (!index.hasValue()) {
|
||||||
|
regions.emplace_back(&before(), before().getArguments());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(*index < 2 && "there are only two regions in a WhileOp");
|
||||||
|
if (*index == 0) {
|
||||||
|
regions.emplace_back(&after(), after().getArguments());
|
||||||
|
regions.emplace_back(getResults());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
regions.emplace_back(&before(), before().getArguments());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parses a `while` op.
|
||||||
|
///
|
||||||
|
/// op ::= `scf.while` assignments `:` function-type region `do` region
|
||||||
|
/// `attributes` attribute-dict
|
||||||
|
/// initializer ::= /* empty */ | `(` assignment-list `)`
|
||||||
|
/// assignment-list ::= assignment | assignment `,` assignment-list
|
||||||
|
/// assignment ::= ssa-value `=` ssa-value
|
||||||
|
static ParseResult parseWhileOp(OpAsmParser &parser, OperationState &result) {
|
||||||
|
SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
|
||||||
|
Region *before = result.addRegion();
|
||||||
|
Region *after = result.addRegion();
|
||||||
|
|
||||||
|
OptionalParseResult listResult =
|
||||||
|
parser.parseOptionalAssignmentList(regionArgs, operands);
|
||||||
|
if (listResult.hasValue() && failed(listResult.getValue()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
FunctionType functionType;
|
||||||
|
llvm::SMLoc typeLoc = parser.getCurrentLocation();
|
||||||
|
if (failed(parser.parseColonType(functionType)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
result.addTypes(functionType.getResults());
|
||||||
|
|
||||||
|
if (functionType.getNumInputs() != operands.size()) {
|
||||||
|
return parser.emitError(typeLoc)
|
||||||
|
<< "expected as many input types as operands "
|
||||||
|
<< "(expected " << operands.size() << " got "
|
||||||
|
<< functionType.getNumInputs() << ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve input operands.
|
||||||
|
if (failed(parser.resolveOperands(operands, functionType.getInputs(),
|
||||||
|
parser.getCurrentLocation(),
|
||||||
|
result.operands)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
return failure(
|
||||||
|
parser.parseRegion(*before, regionArgs, functionType.getInputs()) ||
|
||||||
|
parser.parseKeyword("do") || parser.parseRegion(*after) ||
|
||||||
|
parser.parseOptionalAttrDictWithKeyword(result.attributes));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prints a `while` op.
|
||||||
|
static void print(OpAsmPrinter &p, scf::WhileOp op) {
|
||||||
|
p << op.getOperationName();
|
||||||
|
printInitializationList(p, op.before().front().getArguments(), op.inits(),
|
||||||
|
" ");
|
||||||
|
p << " : ";
|
||||||
|
p.printFunctionalType(op.inits().getTypes(), op.results().getTypes());
|
||||||
|
p.printRegion(op.before(), /*printEntryBlockArgs=*/false);
|
||||||
|
p << " do";
|
||||||
|
p.printRegion(op.after());
|
||||||
|
p.printOptionalAttrDictWithKeyword(op.getAttrs());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verifies that two ranges of types match, i.e. have the same number of
|
||||||
|
/// entries and that types are pairwise equals. Reports errors on the given
|
||||||
|
/// operation in case of mismatch.
|
||||||
|
template <typename OpTy>
|
||||||
|
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
|
||||||
|
TypeRange right, StringRef message) {
|
||||||
|
if (left.size() != right.size())
|
||||||
|
return op.emitOpError("expects the same number of ") << message;
|
||||||
|
|
||||||
|
for (unsigned i = 0, e = left.size(); i < e; ++i) {
|
||||||
|
if (left[i] != right[i]) {
|
||||||
|
InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
|
||||||
|
<< message;
|
||||||
|
diag.attachNote() << "for argument " << i << ", found " << left[i]
|
||||||
|
<< " and " << right[i];
|
||||||
|
return diag;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verifies that the first block of the given `region` is terminated by a
|
||||||
|
/// YieldOp. Reports errors on the given operation if it is not the case.
|
||||||
|
template <typename TerminatorTy>
|
||||||
|
static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region ®ion,
|
||||||
|
StringRef errorMessage) {
|
||||||
|
Operation *terminatorOperation = region.front().getTerminator();
|
||||||
|
if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
|
||||||
|
return yield;
|
||||||
|
|
||||||
|
auto diag = op.emitOpError(errorMessage);
|
||||||
|
if (terminatorOperation)
|
||||||
|
diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verify(scf::WhileOp op) {
|
||||||
|
if (failed(RegionBranchOpInterface::verifyTypes(op)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
|
||||||
|
op, op.before(),
|
||||||
|
"expects the 'before' region to terminate with 'scf.condition'");
|
||||||
|
if (!beforeTerminator)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
TypeRange trailingTerminatorOperands = beforeTerminator.args().getTypes();
|
||||||
|
if (failed(verifyTypeRangesMatch(op, trailingTerminatorOperands,
|
||||||
|
op.after().getArgumentTypes(),
|
||||||
|
"trailing operands of the 'before' block "
|
||||||
|
"terminator and 'after' region arguments")))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (failed(verifyTypeRangesMatch(
|
||||||
|
op, trailingTerminatorOperands, op.getResultTypes(),
|
||||||
|
"trailing operands of the 'before' block terminator and op results")))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
|
||||||
|
op, op.after(),
|
||||||
|
"expects the 'after' region to terminate with 'scf.yield'");
|
||||||
|
return success(afterTerminator != nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// YieldOp
|
// YieldOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -76,10 +76,13 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
|
||||||
/// Verify that types match along all region control flow edges originating from
|
/// Verify that types match along all region control flow edges originating from
|
||||||
/// `sourceNo` (region # if source is a region, llvm::None if source is parent
|
/// `sourceNo` (region # if source is a region, llvm::None if source is parent
|
||||||
/// op). `getInputsTypesForRegion` is a function that returns the types of the
|
/// op). `getInputsTypesForRegion` is a function that returns the types of the
|
||||||
/// inputs that flow from `sourceIndex' to the given region.
|
/// inputs that flow from `sourceIndex' to the given region, or llvm::None if
|
||||||
static LogicalResult verifyTypesAlongAllEdges(
|
/// the exact type match verification is not necessary (e.g., if the Op verifies
|
||||||
Operation *op, Optional<unsigned> sourceNo,
|
/// the match itself).
|
||||||
function_ref<TypeRange(Optional<unsigned>)> getInputsTypesForRegion) {
|
static LogicalResult
|
||||||
|
verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
|
||||||
|
function_ref<Optional<TypeRange>(Optional<unsigned>)>
|
||||||
|
getInputsTypesForRegion) {
|
||||||
auto regionInterface = cast<RegionBranchOpInterface>(op);
|
auto regionInterface = cast<RegionBranchOpInterface>(op);
|
||||||
|
|
||||||
SmallVector<RegionSuccessor, 2> successors;
|
SmallVector<RegionSuccessor, 2> successors;
|
||||||
|
@ -113,17 +116,20 @@ static LogicalResult verifyTypesAlongAllEdges(
|
||||||
return diag;
|
return diag;
|
||||||
};
|
};
|
||||||
|
|
||||||
TypeRange sourceTypes = getInputsTypesForRegion(succRegionNo);
|
Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
|
||||||
|
if (!sourceTypes.hasValue())
|
||||||
|
continue;
|
||||||
|
|
||||||
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
|
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
|
||||||
if (sourceTypes.size() != succInputsTypes.size()) {
|
if (sourceTypes->size() != succInputsTypes.size()) {
|
||||||
InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
|
InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
|
||||||
return printEdgeName(diag) << ": source has " << sourceTypes.size()
|
return printEdgeName(diag) << ": source has " << sourceTypes->size()
|
||||||
<< " operands, but target successor needs "
|
<< " operands, but target successor needs "
|
||||||
<< succInputsTypes.size();
|
<< succInputsTypes.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto typesIdx :
|
for (auto typesIdx :
|
||||||
llvm::enumerate(llvm::zip(sourceTypes, succInputsTypes))) {
|
llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
|
||||||
Type sourceType = std::get<0>(typesIdx.value());
|
Type sourceType = std::get<0>(typesIdx.value());
|
||||||
Type inputType = std::get<1>(typesIdx.value());
|
Type inputType = std::get<1>(typesIdx.value());
|
||||||
if (sourceType != inputType) {
|
if (sourceType != inputType) {
|
||||||
|
@ -191,10 +197,15 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
|
||||||
<< " operands mismatch between return-like terminators";
|
<< " operands mismatch between return-like terminators";
|
||||||
}
|
}
|
||||||
|
|
||||||
auto inputTypesFromRegion = [&](Optional<unsigned> regionNo) -> TypeRange {
|
auto inputTypesFromRegion =
|
||||||
|
[&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
|
||||||
|
// If there is no return-like terminator, the op itself should verify
|
||||||
|
// type consistency.
|
||||||
|
if (!regionReturn)
|
||||||
|
return llvm::None;
|
||||||
|
|
||||||
// All successors get the same set of operands.
|
// All successors get the same set of operands.
|
||||||
return regionReturn ? TypeRange(regionReturn->getOperands().getTypes())
|
return TypeRange(regionReturn->getOperands().getTypes());
|
||||||
: TypeRange();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
|
if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
|
||||||
|
|
|
@ -1480,10 +1480,13 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse a list of assignments of the form
|
/// Parse a list of assignments of the form
|
||||||
/// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...).
|
/// (%x1 = %y1, %x2 = %y2, ...).
|
||||||
/// The list must contain at least one entry
|
OptionalParseResult
|
||||||
ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
|
parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
|
||||||
SmallVectorImpl<OperandType> &rhs) override {
|
SmallVectorImpl<OperandType> &rhs) override {
|
||||||
|
if (failed(parseOptionalLParen()))
|
||||||
|
return llvm::None;
|
||||||
|
|
||||||
auto parseElt = [&]() -> ParseResult {
|
auto parseElt = [&]() -> ParseResult {
|
||||||
OperandType regionArg, operand;
|
OperandType regionArg, operand;
|
||||||
if (parseRegionArgument(regionArg) || parseEqual() ||
|
if (parseRegionArgument(regionArg) || parseEqual() ||
|
||||||
|
@ -1493,8 +1496,6 @@ public:
|
||||||
rhs.push_back(operand);
|
rhs.push_back(operand);
|
||||||
return success();
|
return success();
|
||||||
};
|
};
|
||||||
if (parseLParen())
|
|
||||||
return failure();
|
|
||||||
return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt);
|
return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -425,10 +425,88 @@ func @parallel_invalid_yield(
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @yield_invalid_parent_op() {
|
func @yield_invalid_parent_op() {
|
||||||
"my.op"() ({
|
"my.op"() ({
|
||||||
// expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.if, scf.for, scf.parallel'}}
|
// expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.if, scf.for, scf.parallel, scf.while'}}
|
||||||
scf.yield
|
scf.yield
|
||||||
}) : () -> ()
|
}) : () -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @while_parser_type_mismatch() {
|
||||||
|
%true = constant true
|
||||||
|
// expected-error@+1 {{expected as many input types as operands (expected 0 got 1)}}
|
||||||
|
scf.while : (i32) -> () {
|
||||||
|
scf.condition(%true)
|
||||||
|
} do {
|
||||||
|
scf.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @while_bad_terminator() {
|
||||||
|
// expected-error@+1 {{expects the 'before' region to terminate with 'scf.condition'}}
|
||||||
|
scf.while : () -> () {
|
||||||
|
// expected-note@+1 {{terminator here}}
|
||||||
|
"some.other_terminator"() : () -> ()
|
||||||
|
} do {
|
||||||
|
scf.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @while_cross_region_type_mismatch() {
|
||||||
|
%true = constant true
|
||||||
|
// expected-error@+1 {{expects the same number of trailing operands of the 'before' block terminator and 'after' region arguments}}
|
||||||
|
scf.while : () -> () {
|
||||||
|
scf.condition(%true)
|
||||||
|
} do {
|
||||||
|
^bb0(%arg0: i32):
|
||||||
|
scf.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @while_cross_region_type_mismatch() {
|
||||||
|
%true = constant true
|
||||||
|
// expected-error@+2 {{expects the same types for trailing operands of the 'before' block terminator and 'after' region arguments}}
|
||||||
|
// expected-note@+1 {{for argument 0, found 'i1' and 'i32}}
|
||||||
|
scf.while : () -> () {
|
||||||
|
scf.condition(%true) %true : i1
|
||||||
|
} do {
|
||||||
|
^bb0(%arg0: i32):
|
||||||
|
scf.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @while_result_type_mismatch() {
|
||||||
|
%true = constant true
|
||||||
|
// expected-error@+1 {{expects the same number of trailing operands of the 'before' block terminator and op results}}
|
||||||
|
scf.while : () -> () {
|
||||||
|
scf.condition(%true) %true : i1
|
||||||
|
} do {
|
||||||
|
^bb0(%arg0: i1):
|
||||||
|
scf.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @while_bad_terminator() {
|
||||||
|
%true = constant true
|
||||||
|
// expected-error@+1 {{expects the 'after' region to terminate with 'scf.yield'}}
|
||||||
|
scf.while : () -> () {
|
||||||
|
scf.condition(%true)
|
||||||
|
} do {
|
||||||
|
// expected-note@+1 {{terminator here}}
|
||||||
|
"some.other_terminator"() : () -> ()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -240,3 +240,42 @@ func @conditional_reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %ste
|
||||||
// CHECK-NEXT: scf.yield %[[IFRES]] : f32
|
// CHECK-NEXT: scf.yield %[[IFRES]] : f32
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: return %[[RESULT]]
|
// CHECK-NEXT: return %[[RESULT]]
|
||||||
|
|
||||||
|
// CHECK-LABEL: @while
|
||||||
|
func @while() {
|
||||||
|
%0 = "test.get_some_value"() : () -> i32
|
||||||
|
%1 = "test.get_some_value"() : () -> f32
|
||||||
|
|
||||||
|
// CHECK: = scf.while (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) : (i32, f32) -> (i64, f64) {
|
||||||
|
%2:2 = scf.while (%arg0 = %0, %arg1 = %1) : (i32, f32) -> (i64, f64) {
|
||||||
|
%3:2 = "test.some_operation"(%arg0, %arg1) : (i32, f32) -> (i64, f64)
|
||||||
|
%4 = "test.some_condition"(%arg0, %arg1) : (i32, f32) -> i1
|
||||||
|
// CHECK: scf.condition(%{{.*}}) %{{.*}}, %{{.*}} : i64, f64
|
||||||
|
scf.condition(%4) %3#0, %3#1 : i64, f64
|
||||||
|
// CHECK: } do {
|
||||||
|
} do {
|
||||||
|
// CHECK: ^{{.*}}(%{{.*}}: i64, %{{.*}}: f64):
|
||||||
|
^bb0(%arg2: i64, %arg3: f64):
|
||||||
|
%5:2 = "test.some_operation"(%arg2, %arg3): (i64, f64) -> (i32, f32)
|
||||||
|
// CHECK: scf.yield %{{.*}}, %{{.*}} : i32, f32
|
||||||
|
scf.yield %5#0, %5#1 : i32, f32
|
||||||
|
// CHECK: attributes {foo = "bar"}
|
||||||
|
} attributes {foo="bar"}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @infinite_while
|
||||||
|
func @infinite_while() {
|
||||||
|
%true = constant true
|
||||||
|
|
||||||
|
// CHECK: scf.while : () -> () {
|
||||||
|
scf.while : () -> () {
|
||||||
|
// CHECK: scf.condition(%{{.*}})
|
||||||
|
scf.condition(%true)
|
||||||
|
// CHECK: } do {
|
||||||
|
} do {
|
||||||
|
// CHECK: scf.yield
|
||||||
|
scf.yield
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue