forked from OSchip/llvm-project
[MLIR] Allow Loop dialect IfOp and ForOp to define values
This patch implements the RFCs proposed here: https://llvm.discourse.group/t/rfc-modify-ifop-in-loop-dialect-to-yield-values/463 https://llvm.discourse.group/t/rfc-adding-operands-and-results-to-loop-for/459/19. It introduces the following changes: - All Loop op regions, except for ReduceOp's, terminate with a YieldOp. - YieldOp can have variadic operands that are used to return values out of IfOp and ForOp regions. - Change IfOp and ForOp syntax and representation to define values. - Add unit-tests and update .td documentation. - YieldOp is a terminator to loop.for/if/parallel. - YieldOp has a custom parser and printer. Lowering is not supported at the moment, and will be in a follow-up PR. Thanks. Reviewed By: bondhugula, nicolasvasilache, rriddle Differential Revision: https://reviews.llvm.org/D74174
This commit is contained in:
parent
31ec721516
commit
bc7b26c333
|
@ -79,7 +79,7 @@ public:
|
|||
if (i != e - 1)
|
||||
rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
|
||||
newLineCst);
|
||||
rewriter.create<loop::TerminatorOp>(loc);
|
||||
rewriter.create<loop::YieldOp>(loc);
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
}
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ public:
|
|||
if (i != e - 1)
|
||||
rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
|
||||
newLineCst);
|
||||
rewriter.create<loop::TerminatorOp>(loc);
|
||||
rewriter.create<loop::YieldOp>(loc);
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
}
|
||||
|
||||
|
|
|
@ -37,31 +37,88 @@ class Loop_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
|
||||
def ForOp : Loop_Op<"for",
|
||||
[DeclareOpInterfaceMethods<LoopLikeOpInterface>,
|
||||
SingleBlockImplicitTerminator<"TerminatorOp">]> {
|
||||
SingleBlockImplicitTerminator<"YieldOp">]> {
|
||||
let summary = "for operation";
|
||||
let description = [{
|
||||
The "loop.for" operation represents a loop nest taking 3 SSA value as
|
||||
The "loop.for" operation represents a loop taking 3 SSA values as
|
||||
operands that represent the lower bound, upper bound and step respectively.
|
||||
The operation defines an SSA value for its induction variable. It has one
|
||||
region capturing the loop body. The induction variable is represented as an
|
||||
argument of this region. This SSA value always has type index, which is the
|
||||
size of the machine word. The step is a value of type index, required to be
|
||||
positive.
|
||||
The lower and upper bounds specify a half-open range: the range includes the
|
||||
lower bound but does not include the upper bound.
|
||||
The lower and upper bounds specify a half-open range: the range includes
|
||||
the lower bound but does not include the upper bound.
|
||||
|
||||
The body region must contain exactly one block that terminates with
|
||||
"loop.terminator". Calling ForOp::build will create such region and insert
|
||||
the terminator, so will the parsing even in cases when it is absent from the
|
||||
custom format. For example:
|
||||
"loop.yield". Calling ForOp::build will create such a region and insert
|
||||
the terminator implicitly if none is defined, so will the parsing even
|
||||
in cases when it is absent from the custom format. For example:
|
||||
|
||||
```mlir
|
||||
loop.for %iv = %lb to %ub step %step {
|
||||
... // body
|
||||
}
|
||||
```
|
||||
|
||||
"loop.for" can also operate on loop-carried variables and returns the final values
|
||||
after loop termination. The initial values of the variables are passed as additional SSA
|
||||
operands to the "loop.for" following the 3 loop control SSA values mentioned above
|
||||
(lower bound, upper bound and step). The operation region has equivalent arguments
|
||||
for each variable representing the value of the variable at the current iteration.
|
||||
|
||||
The region must terminate with a "loop.yield" that passes all the current iteration
|
||||
variables to the next iteration, or to the "loop.for" result, if at the last iteration.
|
||||
"loop.for" results hold the final values after the last iteration.
|
||||
|
||||
For example, to sum-reduce a memref:
|
||||
|
||||
```mlir
|
||||
func @reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index) -> (f32) {
|
||||
// Initial sum set to 0.
|
||||
%sum_0 = constant 0.0 : f32
|
||||
// iter_args binds initial values to the loop's region arguments.
|
||||
%sum = loop.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_0) -> (f32) {
|
||||
%t = load %buffer[%iv] : memref<1024xf32>
|
||||
%sum_next = addf %sum_iter, %t : f32
|
||||
// Yield current iteration sum to next iteration %sum_iter or to %sum if final iteration.
|
||||
loop.yield %sum_next : f32
|
||||
}
|
||||
return %sum : f32
|
||||
}
|
||||
```
|
||||
|
||||
If the "loop.for" defines any values, a yield must be explicitly present.
|
||||
The number and types of the "loop.for" results must match the initial values
|
||||
in the "iter_args" binding and the yield operands.
|
||||
|
||||
Another example with a nested "loop.if" (see "loop.if" for details)
|
||||
to perform conditional reduction:
|
||||
|
||||
```mlir
|
||||
func @conditional_reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index) -> (f32) {
|
||||
%sum_0 = constant 0.0 : f32
|
||||
%c0 = constant 0.0 : f32
|
||||
%sum = loop.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_0) -> (f32) {
|
||||
%t = load %buffer[%iv] : memref<1024xf32>
|
||||
%cond = cmpf "ugt", %t, %c0 : f32
|
||||
%sum_next = loop.if %cond -> (f32) {
|
||||
%new_sum = addf %sum_iter, %t : f32
|
||||
loop.yield %new_sum : f32
|
||||
} else {
|
||||
loop.yield %sum_iter : f32
|
||||
}
|
||||
loop.yield %sum_next : f32
|
||||
}
|
||||
return %sum : f32
|
||||
}
|
||||
```
|
||||
}];
|
||||
let arguments = (ins Index:$lowerBound, Index:$upperBound, Index:$step);
|
||||
let arguments = (ins Index:$lowerBound,
|
||||
Index:$upperBound,
|
||||
Index:$step,
|
||||
Variadic<AnyType>:$initArgs);
|
||||
let results = (outs Variadic<AnyType>:$results);
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
|
@ -76,19 +133,41 @@ def ForOp : Loop_Op<"for",
|
|||
OpBuilder getBodyBuilder() {
|
||||
return OpBuilder(getBody(), std::prev(getBody()->end()));
|
||||
}
|
||||
iterator_range<Block::args_iterator> getRegionIterArgs() {
|
||||
return getBody()->getArguments().drop_front();
|
||||
}
|
||||
iterator_range<Operation::operand_iterator> getIterOperands() {
|
||||
return getOperands().drop_front(getNumControlOperands());
|
||||
}
|
||||
|
||||
void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
|
||||
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
|
||||
void setStep(Value step) { getOperation()->setOperand(2, step); }
|
||||
|
||||
/// Number of region arguments for loop-carried values
|
||||
unsigned getNumRegionIterArgs() {
|
||||
return getBody()->getNumArguments() - 1;
|
||||
}
|
||||
/// Number of operands controlling the loop: lb, ub, step
|
||||
constexpr unsigned getNumControlOperands() { return 3; }
|
||||
/// Does the operation hold operands for loop-carried values
|
||||
bool hasIterOperands() {
|
||||
return getOperation()->getNumOperands() > getNumControlOperands();
|
||||
}
|
||||
/// Get Number of loop-carried values
|
||||
unsigned getNumIterOperands() {
|
||||
return getOperation()->getNumOperands() - getNumControlOperands();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def IfOp : Loop_Op<"if",
|
||||
[SingleBlockImplicitTerminator<"TerminatorOp">]> {
|
||||
[SingleBlockImplicitTerminator<"YieldOp">]> {
|
||||
let summary = "if-then-else operation";
|
||||
let description = [{
|
||||
The "loop.if" operation represents an if-then-else construct for
|
||||
conditionally executing two regions of code. The operand to an if operation
|
||||
is a boolean value. The operation produces no results. For example:
|
||||
is a boolean value. For example:
|
||||
|
||||
```mlir
|
||||
loop.if %b {
|
||||
|
@ -98,9 +177,28 @@ def IfOp : Loop_Op<"if",
|
|||
}
|
||||
```
|
||||
|
||||
The 'else' block is optional, and may be omitted. For
|
||||
example:
|
||||
"loop.if" may also return results that are defined in its regions. The values
|
||||
defined are determined by which execution path is taken.
|
||||
For example:
|
||||
```mlir
|
||||
%x, %y = loop.if %b -> (f32, f32) {
|
||||
%x_true = ...
|
||||
%y_true = ...
|
||||
loop.yield %x_true, %y_true : f32, f32
|
||||
} else {
|
||||
%x_false = ...
|
||||
%y_false = ...
|
||||
loop.yield %x_false, %y_false : f32, f32
|
||||
}
|
||||
```
|
||||
|
||||
"loop.if" regions are always terminated with "loop.yield". If "loop.if"
|
||||
defines no values, the "loop.yield" can be left out, and will be
|
||||
inserted implicitly. Otherwise, it must be explicit.
|
||||
Also, if "loop.if" defines one or more values, the 'else' block cannot
|
||||
be omitted.
|
||||
|
||||
For example:
|
||||
```mlir
|
||||
loop.if %b {
|
||||
...
|
||||
|
@ -108,6 +206,7 @@ def IfOp : Loop_Op<"if",
|
|||
```
|
||||
}];
|
||||
let arguments = (ins I1:$condition);
|
||||
let results = (outs Variadic<AnyType>:$results);
|
||||
let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
|
@ -131,7 +230,7 @@ def IfOp : Loop_Op<"if",
|
|||
}
|
||||
|
||||
def ParallelOp : Loop_Op<"parallel",
|
||||
[SameVariadicOperandSize, SingleBlockImplicitTerminator<"TerminatorOp">]> {
|
||||
[SameVariadicOperandSize, SingleBlockImplicitTerminator<"YieldOp">]> {
|
||||
let summary = "parallel for operation";
|
||||
let description = [{
|
||||
The "loop.parallel" operation represents a loop nest taking 3 groups of SSA
|
||||
|
@ -157,8 +256,8 @@ def ParallelOp : Loop_Op<"parallel",
|
|||
the same number of results as it has reduce operations.
|
||||
|
||||
The body region must contain exactly one block that terminates with
|
||||
"loop.terminator". Parsing ParallelOp will create such region and insert the
|
||||
terminator when it is absent from the custom format. For example:
|
||||
"loop.yield" without operands. Parsing ParallelOp will create such a region
|
||||
and insert the terminator when it is absent from the custom format. For example:
|
||||
|
||||
```mlir
|
||||
loop.parallel (%iv) = (%lb) to (%ub) step (%step) {
|
||||
|
@ -262,25 +361,23 @@ def ReduceReturnOp :
|
|||
let assemblyFormat = "$result attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def TerminatorOp : Loop_Op<"terminator", [Terminator]> {
|
||||
let summary = "cf terminator operation";
|
||||
def YieldOp : Loop_Op<"yield", [Terminator]> {
|
||||
let summary = "loop yield and termination operation";
|
||||
let description = [{
|
||||
"loop.terminator" is a special terminator operation for blocks inside
|
||||
loops. It terminates the region. This operation does _not_ have a custom
|
||||
syntax. However, `std` control operations omit the terminator in their
|
||||
custom syntax for brevity.
|
||||
|
||||
```mlir
|
||||
loop.terminator
|
||||
```
|
||||
"loop.yield" yields an SSA value from a loop dialect op region and
|
||||
terminates the region. The semantics of how the values are yielded
|
||||
are defined by the parent operation.
|
||||
If "loop.yield" has any operands, the operands must match the parent
|
||||
operation's results.
|
||||
If the parent operation defines no values, then the "loop.yield" may be
|
||||
left out in the custom syntax and the builders will insert one implicitly.
|
||||
Otherwise, it has to be present in the syntax to indicate which values
|
||||
are yielded.
|
||||
}];
|
||||
|
||||
// No custom parsing/printing form.
|
||||
let parser = ?;
|
||||
let printer = ?;
|
||||
|
||||
// Fully specified by traits.
|
||||
let verifier = ?;
|
||||
let arguments = (ins Variadic<AnyType>:$results);
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &result", [{ /* nothing to do */ }]>
|
||||
];
|
||||
}
|
||||
|
||||
#endif // LOOP_OPS
|
||||
|
|
|
@ -608,6 +608,9 @@ public:
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Parse an arrow followed by a type list.
|
||||
virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
|
||||
|
||||
/// Parse an optional arrow followed by a type list.
|
||||
virtual ParseResult
|
||||
parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
|
||||
|
@ -641,6 +644,13 @@ public:
|
|||
virtual ParseResult
|
||||
parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
|
||||
|
||||
/// Parse a list of assignments of the form
|
||||
/// (%x1 = %y1, %x2 = %y2, ...).
|
||||
/// The list must contain at least one entry
|
||||
virtual ParseResult
|
||||
parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
|
||||
SmallVectorImpl<OperandType> &rhs) = 0;
|
||||
|
||||
/// Parse a keyword followed by a type.
|
||||
ParseResult parseKeywordType(const char *keyword, Type &result) {
|
||||
return failure(parseKeyword(keyword) || parseType(result));
|
||||
|
|
|
@ -332,7 +332,7 @@ public:
|
|||
|
||||
PatternMatchResult matchAndRewrite(AffineTerminatorOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<loop::TerminatorOp>(op);
|
||||
rewriter.replaceOpWithNewOp<loop::YieldOp>(op);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -42,14 +42,13 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Pattern to erase a loop::TerminatorOp.
|
||||
class TerminatorOpConversion final
|
||||
: public SPIRVOpLowering<loop::TerminatorOp> {
|
||||
/// Pattern to erase a loop::YieldOp.
|
||||
class TerminatorOpConversion final : public SPIRVOpLowering<loop::YieldOp> {
|
||||
public:
|
||||
using SPIRVOpLowering<loop::TerminatorOp>::SPIRVOpLowering;
|
||||
using SPIRVOpLowering<loop::YieldOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(loop::TerminatorOp terminatorOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(loop::YieldOp terminatorOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.eraseOp(terminatorOp);
|
||||
return matchSuccess();
|
||||
|
|
|
@ -76,18 +76,60 @@ static LogicalResult verify(ForOp op) {
|
|||
// Check that the first block argument of the body is an index: the induction
|
||||
// variable.
|
||||
auto *body = op.getBody();
|
||||
if (body->getNumArguments() != 1 || !body->getArgument(0).getType().isIndex())
|
||||
return op.emitOpError("expected body to have a single index argument for "
|
||||
"the induction variable");
|
||||
if (!body->getArgument(0).getType().isIndex())
|
||||
return op.emitOpError(
|
||||
"expected body first argument to be an index argument for "
|
||||
"the induction variable");
|
||||
|
||||
auto opNumResults = op.getNumResults();
|
||||
if (opNumResults == 0)
|
||||
return success();
|
||||
// If ForOp defines values, check that the number and types of
|
||||
// the defined values match ForOp initial iter operands and backedge
|
||||
// basic block arguments.
|
||||
if (op.getNumIterOperands() != opNumResults)
|
||||
return op.emitOpError(
|
||||
"mismatch in number of loop-carried values and defined values");
|
||||
if (op.getNumRegionIterArgs() != opNumResults)
|
||||
return op.emitOpError(
|
||||
"mismatch in number of basic block args and defined values");
|
||||
auto iterOperands = op.getIterOperands();
|
||||
auto iterArgs = op.getRegionIterArgs();
|
||||
auto opResults = op.getResults();
|
||||
unsigned i = 0;
|
||||
for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
|
||||
if (std::get<0>(e).getType() != std::get<2>(e).getType())
|
||||
return op.emitOpError() << "types mismatch between " << i
|
||||
<< "th iter operand and defined value";
|
||||
if (std::get<1>(e).getType() != std::get<2>(e).getType())
|
||||
return op.emitOpError() << "types mismatch between " << i
|
||||
<< "th iter region arg and defined value";
|
||||
|
||||
i++;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, ForOp op) {
|
||||
bool printBlockTerminators = false;
|
||||
p << op.getOperationName() << " " << op.getInductionVar() << " = "
|
||||
<< op.lowerBound() << " to " << op.upperBound() << " step " << op.step();
|
||||
|
||||
if (op.hasIterOperands()) {
|
||||
p << " iter_args(";
|
||||
auto regionArgs = op.getRegionIterArgs();
|
||||
auto operands = op.getIterOperands();
|
||||
|
||||
mlir::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
|
||||
p << std::get<0>(it) << " = " << std::get<1>(it);
|
||||
});
|
||||
p << ")";
|
||||
p << " -> (" << op.getResultTypes() << ")";
|
||||
printBlockTerminators = true;
|
||||
}
|
||||
p.printRegion(op.region(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
/*printBlockTerminators=*/printBlockTerminators);
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
}
|
||||
|
||||
|
@ -108,9 +150,34 @@ static ParseResult parseForOp(OpAsmParser &parser, OperationState &result) {
|
|||
parser.resolveOperand(step, indexType, result.operands))
|
||||
return failure();
|
||||
|
||||
// Parse the optional initial iteration arguments.
|
||||
SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
|
||||
SmallVector<Type, 4> argTypes;
|
||||
regionArgs.push_back(inductionVariable);
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
|
||||
// Parse assignment list and results type list.
|
||||
if (parser.parseAssignmentList(regionArgs, operands) ||
|
||||
parser.parseArrowTypeList(result.types))
|
||||
return failure();
|
||||
// Resolve input operands.
|
||||
for (auto operand_type : llvm::zip(operands, result.types))
|
||||
if (parser.resolveOperand(std::get<0>(operand_type),
|
||||
std::get<1>(operand_type), result.operands))
|
||||
return failure();
|
||||
}
|
||||
// Induction variable.
|
||||
argTypes.push_back(indexType);
|
||||
// Loop carried variables
|
||||
argTypes.append(result.types.begin(), result.types.end());
|
||||
// Parse the body region.
|
||||
Region *body = result.addRegion();
|
||||
if (parser.parseRegion(*body, inductionVariable, indexType))
|
||||
if (regionArgs.size() != argTypes.size())
|
||||
return parser.emitError(
|
||||
parser.getNameLoc(),
|
||||
"mismatch in number of loop-carried values and defined values");
|
||||
|
||||
if (parser.parseRegion(*body, regionArgs, argTypes))
|
||||
return failure();
|
||||
|
||||
ForOp::ensureTerminator(*body, builder, result.location);
|
||||
|
@ -168,6 +235,9 @@ static LogicalResult verify(IfOp op) {
|
|||
return op.emitOpError(
|
||||
"requires that child entry blocks have no arguments");
|
||||
}
|
||||
if (op.getNumResults() != 0 && op.elseRegion().empty())
|
||||
return op.emitOpError("must have an else block if defining values");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -183,7 +253,9 @@ static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) {
|
|||
if (parser.parseOperand(cond) ||
|
||||
parser.resolveOperand(cond, i1Type, result.operands))
|
||||
return failure();
|
||||
|
||||
// Parse optional results type list.
|
||||
if (parser.parseOptionalArrowTypeList(result.types))
|
||||
return failure();
|
||||
// Parse the 'then' region.
|
||||
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
|
||||
return failure();
|
||||
|
@ -199,15 +271,21 @@ static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) {
|
|||
// Parse the optional attribute list.
|
||||
if (parser.parseOptionalAttrDict(result.attributes))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, IfOp op) {
|
||||
bool printBlockTerminators = false;
|
||||
|
||||
p << IfOp::getOperationName() << " " << op.condition();
|
||||
if (!op.results().empty()) {
|
||||
p << " -> (" << op.getResultTypes() << ")";
|
||||
// Print yield explicitly if the op defines values.
|
||||
printBlockTerminators = true;
|
||||
}
|
||||
p.printRegion(op.thenRegion(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
/*printBlockTerminators=*/printBlockTerminators);
|
||||
|
||||
// Print the 'else' regions if it exists and has a block.
|
||||
auto &elseRegion = op.elseRegion();
|
||||
|
@ -215,7 +293,7 @@ static void print(OpAsmPrinter &p, IfOp op) {
|
|||
p << " else";
|
||||
p.printRegion(elseRegion,
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
/*printBlockTerminators=*/printBlockTerminators);
|
||||
}
|
||||
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
|
@ -434,6 +512,54 @@ static LogicalResult verify(ReduceReturnOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// YieldOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Verify a loop.yield against the operation that encloses it.
///
/// - Inside loop.parallel the yield must carry no operands.
/// - Inside loop.if / loop.for the yield must carry exactly as many operands
///   as the parent has results, with pairwise identical types.
/// - Any other parent is rejected.
static LogicalResult verify(YieldOp op) {
  auto *enclosingOp = op.getParentOp();
  auto parentResults = enclosingOp->getResults();
  auto yielded = op.getOperands();

  // loop.parallel: the terminator is purely structural.
  if (isa<ParallelOp>(enclosingOp)) {
    if (op.getNumOperands() != 0)
      return op.emitOpError()
             << "yield inside loop.parallel is not allowed to have operands";
    return success();
  }

  // Only loop.if and loop.for remain as legal parents.
  if (!isa<IfOp>(enclosingOp) && !isa<ForOp>(enclosingOp))
    return op.emitOpError()
           << "yield only terminates If, For or Parallel regions";

  // The yielded values become the parent's results: arity first, then types.
  if (enclosingOp->getNumResults() != op.getNumOperands())
    return op.emitOpError() << "parent of yield must have same number of "
                               "results as the yield operands";

  for (auto pair : llvm::zip(parentResults, yielded)) {
    auto resultType = std::get<0>(pair).getType();
    auto operandType = std::get<1>(pair).getType();
    if (resultType != operandType)
      return op.emitOpError()
             << "types mismatch between yield op and its parent";
  }

  return success();
}
|
||||
|
||||
/// Parse the custom form of loop.yield:
///
///   loop.yield
///   loop.yield %a, %b : type-a, type-b
///
/// The operand list and the trailing colon-type list are both optional; the
/// parsed operands are resolved against the parsed types and appended to
/// `result.operands`.
static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> yieldedOperands;
  SmallVector<Type, 4> yieldedTypes;
  // Remember where the operand list starts so resolution errors point at it.
  llvm::SMLoc operandsLoc = parser.getCurrentLocation();

  if (parser.parseOperandList(yieldedOperands))
    return failure();
  if (parser.parseOptionalColonTypeList(yieldedTypes))
    return failure();
  if (parser.resolveOperands(yieldedOperands, yieldedTypes, operandsLoc,
                             result.operands))
    return failure();

  return success();
}
|
||||
|
||||
/// Print loop.yield in its custom form: the bare operation name when there is
/// nothing to yield, otherwise the name followed by the operands and their
/// types separated by " : ".
static void print(OpAsmPrinter &p, YieldOp op) {
  p << op.getOperationName();

  // Nothing more to print for an operand-less yield.
  if (op.getNumOperands() == 0)
    return;

  p << ' ' << op.getOperands();
  p << " : " << op.getOperandTypes();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -4432,6 +4432,13 @@ public:
|
|||
return failure(!(result = parser.parseType()));
|
||||
}
|
||||
|
||||
/// Parse an arrow followed by a type list.
|
||||
ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override {
|
||||
if (parseArrow() || parser.parseFunctionResultTypes(result))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Parse an optional arrow followed by a type list.
|
||||
ParseResult
|
||||
parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override {
|
||||
|
@ -4462,6 +4469,26 @@ public:
|
|||
return parser.parseTypeListNoParens(result);
|
||||
}
|
||||
|
||||
/// Parse a list of assignments of the form
|
||||
/// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...).
|
||||
/// The list must contain at least one entry
|
||||
ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
|
||||
SmallVectorImpl<OperandType> &rhs) {
|
||||
auto parseElt = [&]() -> ParseResult {
|
||||
OperandType regionArg, operand;
|
||||
Type type;
|
||||
if (parseRegionArgument(regionArg) || parseEqual() ||
|
||||
parseOperand(operand))
|
||||
return failure();
|
||||
lhs.push_back(regionArg);
|
||||
rhs.push_back(operand);
|
||||
return success();
|
||||
};
|
||||
if (parseLParen())
|
||||
return failure();
|
||||
return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt);
|
||||
}
|
||||
|
||||
private:
|
||||
/// The source location of the operation name.
|
||||
SMLoc nameLoc;
|
||||
|
|
|
@ -246,12 +246,12 @@ module {
|
|||
%19 = load %16[%arg5, %arg6] : memref<?x?xf32, #map2>
|
||||
%20 = addf %17, %18 : f32
|
||||
store %20, %16[%arg5, %arg6] : memref<?x?xf32, #map2>
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
} { mapping = [
|
||||
{processor = 3, map = #map3, bound = #map3},
|
||||
{processor = 4, map = #map3, bound = #map3}
|
||||
] }
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
} { mapping = [
|
||||
{processor = 0, map = #map3, bound = #map3},
|
||||
{processor = 1, map = #map3, bound = #map3}
|
||||
|
|
|
@ -27,7 +27,7 @@ func @linalg_generic_sum(%lhs: memref<2x2xf32>,
|
|||
// CHECK: %[[SUM_ELEM:.*]] = load %[[SUM]][%[[I]], %[[J]]]
|
||||
// CHECK: %[[SUM:.*]] = addf %[[LHS_ELEM]], %[[RHS_ELEM]] : f32
|
||||
// CHECK: store %[[SUM]], %{{.*}}[%[[I]], %[[J]]]
|
||||
// CHECK: "loop.terminator"() : () -> ()
|
||||
// CHECK: loop.yield
|
||||
|
||||
// -----
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ func @loop_for_step_positive(%arg0: index) {
|
|||
%c0 = constant 0 : index
|
||||
"loop.for"(%arg0, %arg0, %c0) ({
|
||||
^bb0(%arg1: index):
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}) : (index, index, index) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -39,8 +39,8 @@ func @loop_for_step_positive(%arg0: index) {
|
|||
func @loop_for_one_region(%arg0: index) {
|
||||
// expected-error@+1 {{incorrect number of regions: expected 1 but found 2}}
|
||||
"loop.for"(%arg0, %arg0, %arg0) (
|
||||
{"loop.terminator"() : () -> ()},
|
||||
{"loop.terminator"() : () -> ()}
|
||||
{loop.yield},
|
||||
{loop.yield}
|
||||
) : (index, index, index) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -52,9 +52,9 @@ func @loop_for_single_block(%arg0: index) {
|
|||
"loop.for"(%arg0, %arg0, %arg0) (
|
||||
{
|
||||
^bb1:
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
^bb2:
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
) : (index, index, index) -> ()
|
||||
return
|
||||
|
@ -63,11 +63,11 @@ func @loop_for_single_block(%arg0: index) {
|
|||
// -----
|
||||
|
||||
func @loop_for_single_index_argument(%arg0: index) {
|
||||
// expected-error@+1 {{expected body to have a single index argument for the induction variable}}
|
||||
// expected-error@+1 {{op expected body first argument to be an index argument for the induction variable}}
|
||||
"loop.for"(%arg0, %arg0, %arg0) (
|
||||
{
|
||||
^bb0(%i0 : f32):
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
) : (index, index, index) -> ()
|
||||
return
|
||||
|
@ -95,9 +95,9 @@ func @loop_if_not_one_block_per_region(%arg0: i1) {
|
|||
// expected-error@+1 {{expects region #0 to have 0 or 1 blocks}}
|
||||
"loop.if"(%arg0) ({
|
||||
^bb0:
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
^bb1:
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}, {}): (i1) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -108,7 +108,7 @@ func @loop_if_illegal_block_argument(%arg0: i1) {
|
|||
// expected-error@+1 {{requires that child entry blocks have no arguments}}
|
||||
"loop.if"(%arg0) ({
|
||||
^bb0(%0 : index):
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}, {}): (i1) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -130,7 +130,7 @@ func @parallel_body_arguments_wrong_type(
|
|||
// expected-error@+1 {{'loop.parallel' op expects arguments for the induction variable to be of index type}}
|
||||
"loop.parallel"(%arg0, %arg1, %arg2) ({
|
||||
^bb0(%i0: f32):
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}): (index, index, index) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -142,7 +142,7 @@ func @parallel_body_wrong_number_of_arguments(
|
|||
// expected-error@+1 {{'loop.parallel' op expects the same number of induction variables as bound and step values}}
|
||||
"loop.parallel"(%arg0, %arg1, %arg2) ({
|
||||
^bb0(%i0: index, %i1: index):
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}): (index, index, index) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -265,7 +265,7 @@ func @reduce_wrong_terminator(%arg0 : index, %arg1 : f32) {
|
|||
// expected-error@+1 {{the block inside reduce should be terminated with a 'loop.reduce.return' op}}
|
||||
loop.reduce(%arg1) {
|
||||
^bb0(%lhs : f32, %rhs : f32):
|
||||
"loop.terminator"(): () -> ()
|
||||
loop.yield
|
||||
} : f32
|
||||
} : f32
|
||||
return
|
||||
|
@ -294,3 +294,87 @@ func @reduceReturn_not_inside_reduce(%arg0 : f32) {
|
|||
}): () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @std_if_incorrect_yield(%arg0: i1, %arg1: f32)
|
||||
{
|
||||
%x, %y = loop.if %arg0 -> (f32, f32) {
|
||||
%0 = addf %arg1, %arg1 : f32
|
||||
// expected-error@+1 {{parent of yield must have same number of results as the yield operands}}
|
||||
loop.yield %0 : f32
|
||||
} else {
|
||||
%0 = subf %arg1, %arg1 : f32
|
||||
loop.yield %0 : f32
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @std_if_missing_else(%arg0: i1, %arg1: f32)
|
||||
{
|
||||
// expected-error@+1 {{must have an else block if defining values}}
|
||||
%x = loop.if %arg0 -> (f32) {
|
||||
%0 = addf %arg1, %arg1 : f32
|
||||
loop.yield %0 : f32
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @std_for_operands_mismatch(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
%s0 = constant 0.0 : f32
|
||||
%t0 = constant 1 : i32
|
||||
// expected-error@+1 {{mismatch in number of loop-carried values and defined values}}
|
||||
%result1:3 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0) -> (f32, i32, f32) {
|
||||
%sn = addf %si, %si : f32
|
||||
%tn = addi %ti, %ti : i32
|
||||
loop.yield %sn, %tn, %sn : f32, i32, f32
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @std_for_operands_mismatch_2(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
%s0 = constant 0.0 : f32
|
||||
%t0 = constant 1 : i32
|
||||
%u0 = constant 1.0 : f32
|
||||
// expected-error@+1 {{mismatch in number of loop-carried values and defined values}}
|
||||
%result1:2 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0, %ui = %u0) -> (f32, i32) {
|
||||
%sn = addf %si, %si : f32
|
||||
%tn = addi %ti, %ti : i32
|
||||
%un = subf %ui, %ui : f32
|
||||
loop.yield %sn, %tn, %un : f32, i32, f32
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
// expected-note@+1 {{prior use here}}
|
||||
%s0 = constant 0.0 : f32
|
||||
%t0 = constant 1.0 : f32
|
||||
// expected-error@+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
|
||||
%result1:2 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0) -> (i32, i32) {
|
||||
%sn = addf %si, %si : i32
|
||||
%tn = addf %ti, %ti : i32
|
||||
loop.yield %sn, %tn : i32, i32
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @parallel_invalid_yield(
|
||||
%arg0: index, %arg1: index, %arg2: index) {
|
||||
loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
|
||||
%c0 = constant 1.0 : f32
|
||||
// expected-error@+1 {{yield inside loop.parallel is not allowed to have operands}}
|
||||
loop.yield %c0 : f32
|
||||
}
|
||||
return
|
||||
}
|
|
@ -90,6 +90,134 @@ func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
|
|||
// CHECK-NEXT: %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: loop.reduce.return %[[RES]] : f32
|
||||
// CHECK-NEXT: } : f32
|
||||
// CHECK-NEXT: "loop.terminator"() : () -> ()
|
||||
// CHECK-NEXT: loop.yield
|
||||
// CHECK-NEXT: } : f32
|
||||
// CHECK-NEXT: "loop.terminator"() : () -> ()
|
||||
// CHECK-NEXT: loop.yield
|
||||
|
||||
// Positive test: a one-dimensional loop.parallel whose body is only an
// explicitly written, operand-less loop.yield. The CHECK lines that follow
// this function expect loop.yield to appear in the printed loop body.
func @parallel_explicit_yield(
|
||||
%arg0: index, %arg1: index, %arg2: index) {
|
||||
loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
|
||||
loop.yield
|
||||
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @parallel_explicit_yield(
|
||||
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]:
|
||||
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]:
|
||||
// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]:
|
||||
// CHECK-NEXT: loop.parallel (%{{.*}}) = (%[[ARG0]]) to (%[[ARG1]]) step (%[[ARG2]])
|
||||
// CHECK-NEXT: loop.yield
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// Positive test: loop.if that defines two f32 results (%x, %y).
// The "then" region yields (addf, subf) of %arg1 with itself; the "else"
// region yields the same two computations in the opposite order.
func @std_if_yield(%arg0: i1, %arg1: f32)
|
||||
{
|
||||
%x, %y = loop.if %arg0 -> (f32, f32) {
|
||||
%0 = addf %arg1, %arg1 : f32
|
||||
|
||||
%1 = subf %arg1, %arg1 : f32
|
||||
|
||||
loop.yield %0, %1 : f32, f32
|
||||
|
||||
} else {
|
||||
%0 = subf %arg1, %arg1 : f32
|
||||
|
||||
%1 = addf %arg1, %arg1 : f32
|
||||
|
||||
loop.yield %0, %1 : f32, f32
|
||||
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
// CHECK-LABEL: func @std_if_yield(
|
||||
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]:
|
||||
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]:
|
||||
// CHECK-NEXT: %{{.*}}:2 = loop.if %[[ARG0]] -> (f32, f32) {
|
||||
// CHECK-NEXT: %[[T1:.*]] = addf %[[ARG1]], %[[ARG1]]
|
||||
// CHECK-NEXT: %[[T2:.*]] = subf %[[ARG1]], %[[ARG1]]
|
||||
// CHECK-NEXT: loop.yield %[[T1]], %[[T2]] : f32, f32
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[T3:.*]] = subf %[[ARG1]], %[[ARG1]]
|
||||
// CHECK-NEXT: %[[T4:.*]] = addf %[[ARG1]], %[[ARG1]]
|
||||
// CHECK-NEXT: loop.yield %[[T3]], %[[T4]] : f32, f32
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// Positive test: loop.for with a single f32 iter_arg (%si, initialised from
// %s0) and a single f32 result. The body yields %sn = %si + %si.
func @std_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
%s0 = constant 0.0 : f32
|
||||
%result = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0) -> (f32) {
|
||||
%sn = addf %si, %si : f32
|
||||
|
||||
loop.yield %sn : f32
|
||||
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
// CHECK-LABEL: func @std_for_yield(
|
||||
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]:
|
||||
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]:
|
||||
// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]:
|
||||
// CHECK-NEXT: %[[INIT:.*]] = constant
|
||||
// CHECK-NEXT: %{{.*}} = loop.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
|
||||
// CHECK-SAME: iter_args(%[[ITER:.*]] = %[[INIT]]) -> (f32) {
|
||||
// CHECK-NEXT: %[[NEXT:.*]] = addf %[[ITER]], %[[ITER]] : f32
|
||||
// CHECK-NEXT: loop.yield %[[NEXT]] : f32
|
||||
// CHECK-NEXT: }
|
||||
|
||||
|
||||
// Positive test: loop.for with three iter_args of mixed types
// (f32, i32, f32) and three matching results. The body yields one updated
// value per iter_arg, in the same order and with the same types.
func @std_for_yield_multi(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
%s0 = constant 0.0 : f32
|
||||
%t0 = constant 1 : i32
|
||||
%u0 = constant 1.0 : f32
|
||||
%result1:3 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0, %ui = %u0) -> (f32, i32, f32) {
|
||||
%sn = addf %si, %si : f32
|
||||
|
||||
%tn = addi %ti, %ti : i32
|
||||
|
||||
%un = subf %ui, %ui : f32
|
||||
|
||||
loop.yield %sn, %tn, %un : f32, i32, f32
|
||||
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
// CHECK-LABEL: func @std_for_yield_multi(
|
||||
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]:
|
||||
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]:
|
||||
// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]:
|
||||
// CHECK-NEXT: %[[INIT1:.*]] = constant
|
||||
// CHECK-NEXT: %[[INIT2:.*]] = constant
|
||||
// CHECK-NEXT: %[[INIT3:.*]] = constant
|
||||
// CHECK-NEXT: %{{.*}}:3 = loop.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
|
||||
// CHECK-SAME: iter_args(%[[ITER1:.*]] = %[[INIT1]], %[[ITER2:.*]] = %[[INIT2]], %[[ITER3:.*]] = %[[INIT3]]) -> (f32, i32, f32) {
|
||||
// CHECK-NEXT: %[[NEXT1:.*]] = addf %[[ITER1]], %[[ITER1]] : f32
|
||||
// CHECK-NEXT: %[[NEXT2:.*]] = addi %[[ITER2]], %[[ITER2]] : i32
|
||||
// CHECK-NEXT: %[[NEXT3:.*]] = subf %[[ITER3]], %[[ITER3]] : f32
|
||||
// CHECK-NEXT: loop.yield %[[NEXT1]], %[[NEXT2]], %[[NEXT3]] : f32, i32, f32
|
||||
|
||||
|
||||
// Positive test combining loop.for and loop.if results.
// The loop carries one f32 value (%sum_iter, initialised from %sum_0).
// Each iteration loads %buffer[%iv] and compares it with %c0 using "ugt".
// The nested loop.if yields %sum_iter + %t in the "then" region and
// %sum_iter unchanged in the "else" region; the loop yields that value,
// and the loop result %sum is returned from the function.
func @conditional_reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index) -> (f32) {
|
||||
%sum_0 = constant 0.0 : f32
|
||||
%c0 = constant 0.0 : f32
|
||||
%sum = loop.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_0) -> (f32) {
|
||||
%t = load %buffer[%iv] : memref<1024xf32>
|
||||
|
||||
%cond = cmpf "ugt", %t, %c0 : f32
|
||||
|
||||
%sum_next = loop.if %cond -> (f32) {
|
||||
%new_sum = addf %sum_iter, %t : f32
|
||||
|
||||
loop.yield %new_sum : f32
|
||||
|
||||
} else {
|
||||
loop.yield %sum_iter : f32
|
||||
|
||||
}
|
||||
loop.yield %sum_next : f32
|
||||
|
||||
}
|
||||
return %sum : f32
|
||||
|
||||
}
|
||||
// CHECK-LABEL: func @conditional_reduce(
|
||||
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]
|
||||
// CHECK-SAME: %[[ARG3:[A-Za-z0-9]+]]
|
||||
// CHECK-NEXT: %[[INIT:.*]] = constant
|
||||
// CHECK-NEXT: %[[ZERO:.*]] = constant
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = loop.for %[[IV:.*]] = %[[ARG1]] to %[[ARG2]] step %[[ARG3]]
|
||||
// CHECK-SAME: iter_args(%[[ITER:.*]] = %[[INIT]]) -> (f32) {
|
||||
// CHECK-NEXT: %[[T:.*]] = load %[[ARG0]][%[[IV]]]
|
||||
// CHECK-NEXT: %[[COND:.*]] = cmpf "ugt", %[[T]], %[[ZERO]]
|
||||
// CHECK-NEXT: %[[IFRES:.*]] = loop.if %[[COND]] -> (f32) {
|
||||
// CHECK-NEXT: %[[THENRES:.*]] = addf %[[ITER]], %[[T]]
|
||||
// CHECK-NEXT: loop.yield %[[THENRES]] : f32
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: loop.yield %[[ITER]] : f32
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: loop.yield %[[IFRES]] : f32
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %[[RESULT]]
|
||||
|
|
|
@ -5,10 +5,10 @@ func @fuse_empty_loops() {
|
|||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ func @fuse_empty_loops() {
|
|||
// CHECK: [[C1:%.*]] = constant 1 : index
|
||||
// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
|
||||
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
|
||||
// CHECK: "loop.terminator"() : () -> ()
|
||||
// CHECK: loop.yield
|
||||
// CHECK: }
|
||||
// CHECK-NOT: loop.parallel
|
||||
|
||||
|
@ -35,14 +35,14 @@ func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
|
|||
%C_elem = load %C[%i, %j] : memref<2x2xf32>
|
||||
%sum_elem = addf %B_elem, %C_elem : f32
|
||||
store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
%sum_elem = load %sum[%i, %j] : memref<2x2xf32>
|
||||
%A_elem = load %A[%i, %j] : memref<2x2xf32>
|
||||
%product_elem = mulf %sum_elem, %A_elem : f32
|
||||
store %product_elem, %result[%i, %j] : memref<2x2xf32>
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
dealloc %sum : memref<2x2xf32>
|
||||
return
|
||||
|
@ -64,7 +64,7 @@ func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
|
|||
// CHECK: [[A_ELEM:%.*]] = load [[A]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK: [[PRODUCT_ELEM:%.*]] = mulf [[SUM_ELEM_]], [[A_ELEM]]
|
||||
// CHECK: store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK: "loop.terminator"() : () -> ()
|
||||
// CHECK: loop.yield
|
||||
// CHECK: }
|
||||
// CHECK: dealloc [[SUM]]
|
||||
|
||||
|
@ -81,20 +81,20 @@ func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
|
|||
loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
|
||||
%rhs_elem = load %rhs[%i] : memref<100xf32>
|
||||
store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32>
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
|
||||
%lhs_elem = load %lhs[%i, %j] : memref<100x10xf32>
|
||||
%broadcast_rhs_elem = load %broadcast_rhs[%i, %j] : memref<100x10xf32>
|
||||
%diff_elem = subf %lhs_elem, %broadcast_rhs_elem : f32
|
||||
store %diff_elem, %diff[%i, %j] : memref<100x10xf32>
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
|
||||
%diff_elem = load %diff[%i, %j] : memref<100x10xf32>
|
||||
%exp_elem = exp %diff_elem : f32
|
||||
store %exp_elem, %result[%i, %j] : memref<100x10xf32>
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
dealloc %broadcast_rhs : memref<100x10xf32>
|
||||
dealloc %diff : memref<100x10xf32>
|
||||
|
@ -120,7 +120,7 @@ func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
|
|||
// CHECK: [[DIFF_ELEM_:%.*]] = load [[DIFF]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK: [[EXP_ELEM:%.*]] = exp [[DIFF_ELEM_]]
|
||||
// CHECK: store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
|
||||
// CHECK: "loop.terminator"() : () -> ()
|
||||
// CHECK: loop.yield
|
||||
// CHECK: }
|
||||
// CHECK: dealloc [[BROADCAST_RHS]]
|
||||
// CHECK: dealloc [[DIFF]]
|
||||
|
@ -133,12 +133,12 @@ func @do_not_fuse_nested_ploop1() {
|
|||
%c1 = constant 1 : index
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
loop.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
@ -154,13 +154,13 @@ func @do_not_fuse_nested_ploop2() {
|
|||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
loop.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
@ -176,10 +176,10 @@ func @do_not_fuse_loops_unmatching_num_loops() {
|
|||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
loop.parallel (%i) = (%c0) to (%c2) step (%c1) {
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
@ -194,11 +194,11 @@ func @do_not_fuse_loops_with_side_effecting_ops_in_between() {
|
|||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
%buffer = alloc() : memref<2x2xf32>
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
@ -214,10 +214,10 @@ func @do_not_fuse_loops_unmatching_iteration_space() {
|
|||
%c2 = constant 2 : index
|
||||
%c4 = constant 4 : index
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c4, %c4) step (%c2, %c2) {
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
@ -239,7 +239,7 @@ func @do_not_fuse_unmatching_write_read_patterns(
|
|||
%C_elem = load %C[%i, %j] : memref<2x2xf32>
|
||||
%sum_elem = addf %B_elem, %C_elem : f32
|
||||
store %sum_elem, %common_buf[%i, %j] : memref<2x2xf32>
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
%k = addi %i, %c1 : index
|
||||
|
@ -247,7 +247,7 @@ func @do_not_fuse_unmatching_write_read_patterns(
|
|||
%A_elem = load %A[%i, %j] : memref<2x2xf32>
|
||||
%product_elem = mulf %sum_elem, %A_elem : f32
|
||||
store %product_elem, %result[%i, %j] : memref<2x2xf32>
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
dealloc %common_buf : memref<2x2xf32>
|
||||
return
|
||||
|
@ -269,7 +269,7 @@ func @do_not_fuse_unmatching_read_write_patterns(
|
|||
%C_elem = load %common_buf[%i, %j] : memref<2x2xf32>
|
||||
%sum_elem = addf %B_elem, %C_elem : f32
|
||||
store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
%k = addi %i, %c1 : index
|
||||
|
@ -277,7 +277,7 @@ func @do_not_fuse_unmatching_read_write_patterns(
|
|||
%A_elem = load %A[%i, %j] : memref<2x2xf32>
|
||||
%product_elem = mulf %sum_elem, %A_elem : f32
|
||||
store %product_elem, %common_buf[%j, %i] : memref<2x2xf32>
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
dealloc %sum : memref<2x2xf32>
|
||||
return
|
||||
|
@ -294,13 +294,13 @@ func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
|
|||
%c1 = constant 1 : index
|
||||
%buffer = alloc() : memref<2x2xf32>
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
|
||||
%A = subview %buffer[%c0, %c0][%c2, %c2][%c1, %c1]
|
||||
: memref<2x2xf32> to memref<?x?xf32, offset: ?, strides:[?, ?]>
|
||||
%A_elem = load %A[%i, %j] : memref<?x?xf32, offset: ?, strides:[?, ?]>
|
||||
"loop.terminator"() : () -> ()
|
||||
loop.yield
|
||||
}
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue