forked from OSchip/llvm-project
Add a standard if op
This CL adds an "std.if" op to represent an if-then-else construct whose condition is an arbitrary value of type i1. This is necessary to lower all the existing examples from affine and linalg to std.for + std.if. This CL introduces the op and adds the relevant positive and negative unit test. Lowering will be done in a separate followup CL. PiperOrigin-RevId: 256649138
This commit is contained in:
parent
471da08e4e
commit
991040478b
|
@ -24,6 +24,7 @@
|
|||
#define MLIR_STANDARDOPS_OPS_H
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
@ -356,6 +357,12 @@ ParseResult parseDimAndSymbolList(OpAsmParser *parser,
|
|||
SmallVector<Value *, 4> &operands,
|
||||
unsigned &numDims);
|
||||
|
||||
// Insert `std.terminator` at the end of the only region's only block if it does
|
||||
// not have a terminator already. If a new `std.terminator` is inserted, the
|
||||
// location is specified by `loc`. If the region is empty, insert a new block
|
||||
// first.
|
||||
void ensureStdTerminator(Region ®ion, Builder &builder, Location loc);
|
||||
|
||||
/// The "std.for" operation represents a loop nest taking 3 SSA value as
|
||||
/// operands that represent the lower bound, upper bound and step respectively.
|
||||
/// The operation defines an SSA value for its induction variable. It has one
|
||||
|
|
|
@ -609,6 +609,46 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
// TODO(ntv): Default generated builder creates IR that does not verify. Atm it
|
||||
// is the responsibility of each caller to call ensureStdTerminator on the
|
||||
// then and else regions.
|
||||
def IfOp : Std_Op<"if"> {
|
||||
let summary = "if-then-else operation";
|
||||
let description = [{
|
||||
The "std.if" operation represents an if-then-else construct for
|
||||
conditionally executing two regions of code. The operand to an if operation
|
||||
is a boolean value. The operation produces no results. For example:
|
||||
|
||||
std.if %b {
|
||||
...
|
||||
} else {
|
||||
...
|
||||
}
|
||||
|
||||
The 'else' block is optional, and may be omitted. For
|
||||
example:
|
||||
|
||||
std.if %b {
|
||||
...
|
||||
}
|
||||
}];
|
||||
let arguments = (ins I1:$condition);
|
||||
let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
OpBuilder getThenBodyBuilder() {
|
||||
assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
|
||||
Block &body = thenRegion().front();
|
||||
return OpBuilder(&body, std::prev(body.end()));
|
||||
}
|
||||
OpBuilder getElseBodyBuilder() {
|
||||
assert(!elseRegion().empty() && "Unexpected empty 'else' region.");
|
||||
Block &body = elseRegion().front();
|
||||
return OpBuilder(&body, std::prev(body.end()));
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
|
||||
let summary = "cast between index and integer types";
|
||||
let description = [{
|
||||
|
|
|
@ -1638,8 +1638,9 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
// StdForOp.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Check that if a "block" has a terminator, it is an `TerminatorOp`.
|
||||
static LogicalResult checkHasTerminator(OpState &op, Block &block) {
|
||||
static LogicalResult checkHasStdTerminator(OpState &op, Block &block) {
|
||||
if (block.empty() || isa<StdTerminatorOp>(block.back()))
|
||||
return success();
|
||||
|
||||
|
@ -1650,11 +1651,7 @@ static LogicalResult checkHasTerminator(OpState &op, Block &block) {
|
|||
<< StdTerminatorOp::getOperationName() << "'";
|
||||
}
|
||||
|
||||
// Insert `cf.terminator` at the end of the StdForOp only region's only block
|
||||
// if it does not have a terminator already. If a new `cf.terminator` is
|
||||
// inserted, the location is specified by `loc`. If the region is empty,
|
||||
// insert a new block first.
|
||||
static void ensureTerminator(Region ®ion, Builder &builder, Location loc) {
|
||||
void mlir::ensureStdTerminator(Region ®ion, Builder &builder, Location loc) {
|
||||
impl::ensureRegionTerminator<StdTerminatorOp>(region, builder, loc);
|
||||
}
|
||||
|
||||
|
@ -1665,7 +1662,7 @@ void StdForOp::build(Builder *builder, OperationState *result, Value *lb,
|
|||
Block *body = new Block();
|
||||
body->addArgument(IndexType::get(builder->getContext()));
|
||||
bodyRegion->push_back(body);
|
||||
ensureTerminator(*bodyRegion, *builder, result->location);
|
||||
ensureStdTerminator(*bodyRegion, *builder, result->location);
|
||||
}
|
||||
|
||||
LogicalResult StdForOp::verify() {
|
||||
|
@ -1694,7 +1691,7 @@ LogicalResult StdForOp::verify() {
|
|||
!body->getArgument(0)->getType().isIndex())
|
||||
return emitOpError("expected body to have a single index argument for "
|
||||
"the induction variable");
|
||||
if (failed(checkHasTerminator(*this, *body)))
|
||||
if (failed(checkHasStdTerminator(*this, *body)))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
@ -1731,7 +1728,7 @@ ParseResult StdForOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
if (parser->parseRegion(*body, inductionVariable, indexType))
|
||||
return failure();
|
||||
|
||||
ensureTerminator(*body, builder, result->location);
|
||||
ensureStdTerminator(*body, builder, result->location);
|
||||
|
||||
// Parse the optional attribute list.
|
||||
if (parser->parseOptionalAttributeDict(result->attributes))
|
||||
|
@ -1754,6 +1751,81 @@ StdForOp getStdForInductionVarOwner(Value *val) {
|
|||
return dyn_cast_or_null<StdForOp>(containingInst);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IfOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(IfOp op) {
|
||||
// Verify that the entry block of each child region does not have arguments.
|
||||
for (auto ®ion : op.getOperation()->getRegions()) {
|
||||
if (region.empty())
|
||||
continue;
|
||||
|
||||
// TODO(riverriddle) We currently do not allow multiple blocks in child
|
||||
// regions.
|
||||
if (std::next(region.begin()) != region.end())
|
||||
return op.emitOpError("expected one block per 'then' or 'else' regions");
|
||||
if (failed(checkHasStdTerminator(op, region.front())))
|
||||
return failure();
|
||||
|
||||
for (auto &b : region)
|
||||
if (b.getNumArguments() != 0)
|
||||
return op.emitOpError(
|
||||
"requires that child entry blocks have no arguments");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static ParseResult parseIfOp(OpAsmParser *parser, OperationState *result) {
|
||||
// Create the regions for 'then'.
|
||||
result->regions.reserve(2);
|
||||
Region *thenRegion = result->addRegion();
|
||||
Region *elseRegion = result->addRegion();
|
||||
|
||||
auto &builder = parser->getBuilder();
|
||||
OpAsmParser::OperandType cond;
|
||||
Type i1Type = builder.getIntegerType(1);
|
||||
if (parser->parseOperand(cond) ||
|
||||
parser->resolveOperand(cond, i1Type, result->operands))
|
||||
return failure();
|
||||
|
||||
// Parse the 'then' region.
|
||||
if (parser->parseRegion(*thenRegion, {}, {}))
|
||||
return failure();
|
||||
ensureStdTerminator(*thenRegion, parser->getBuilder(), result->location);
|
||||
|
||||
// If we find an 'else' keyword then parse the 'else' region.
|
||||
if (!parser->parseOptionalKeyword("else")) {
|
||||
if (parser->parseRegion(*elseRegion, {}, {}))
|
||||
return failure();
|
||||
ensureStdTerminator(*elseRegion, parser->getBuilder(), result->location);
|
||||
}
|
||||
|
||||
// Parse the optional attribute list.
|
||||
if (parser->parseOptionalAttributeDict(result->attributes))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, IfOp op) {
|
||||
*p << IfOp::getOperationName() << " " << *op.condition();
|
||||
p->printRegion(op.thenRegion(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
|
||||
// Print the 'else' regions if it exists and has a block.
|
||||
auto &elseRegion = op.elseRegion();
|
||||
if (!elseRegion.empty()) {
|
||||
*p << " else";
|
||||
p->printRegion(elseRegion,
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
}
|
||||
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IndexCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -437,7 +437,7 @@ func @std_for(%arg0 : index, %arg1 : index, %arg2 : index) {
|
|||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @std_for(%arg0: index, %arg1: index, %arg2: index) {
|
||||
// CHECK-LABEL: func @std_for(
|
||||
// CHECK-NEXT: std.for %i0 = %arg0 to %arg1 step %arg2 {
|
||||
// CHECK-NEXT: std.for %i1 = %arg0 to %arg1 step %arg2 {
|
||||
// CHECK-NEXT: %0 = cmpi "slt", %i0, %i1 : index
|
||||
|
@ -445,3 +445,27 @@ func @std_for(%arg0 : index, %arg1 : index, %arg2 : index) {
|
|||
// CHECK-NEXT: %2 = cmpi "sge", %i0, %i1 : index
|
||||
// CHECK-NEXT: %3 = select %2, %i0, %i1 : index
|
||||
// CHECK-NEXT: std.for %i2 = %1 to %3 step %i1 {
|
||||
|
||||
func @std_if(%arg0: i1, %arg1: f32) {
|
||||
std.if %arg0 {
|
||||
%0 = addf %arg1, %arg1 : f32
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @std_if(
|
||||
// CHECK-NEXT: std.if %arg0 {
|
||||
// CHECK-NEXT: %{{.*}} = addf %arg1, %arg1 : f32
|
||||
|
||||
func @std_if_else(%arg0: i1, %arg1: f32) {
|
||||
std.if %arg0 {
|
||||
%0 = addf %arg1, %arg1 : f32
|
||||
} else {
|
||||
%1 = addf %arg1, %arg1 : f32
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @std_if_else(
|
||||
// CHECK-NEXT: std.if %arg0 {
|
||||
// CHECK-NEXT: %{{.*}} = addf %arg1, %arg1 : f32
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %{{.*}} = addf %arg1, %arg1 : f32
|
||||
|
|
|
@ -770,4 +770,44 @@ func @std_for_single_index_argument(%arg0: index) {
|
|||
}
|
||||
) : (index, index, index) -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @std_if_not_i1(%arg0: index) {
|
||||
// expected-error@+1 {{operand #0 must be 1-bit integer}}
|
||||
"std.if"(%arg0) : (index) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @std_if_more_than_2_regions(%arg0: i1) {
|
||||
// expected-error@+1 {{op has incorrect number of regions: expected 2}}
|
||||
"std.if"(%arg0) ({}, {}, {}): (i1) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @std_if_not_one_block_per_region(%arg0: i1) {
|
||||
// expected-error@+1 {{region #0 ('thenRegion') failed to verify constraint: region with 1 blocks}}
|
||||
"std.if"(%arg0) ({
|
||||
^bb0:
|
||||
"std.terminator"() : () -> ()
|
||||
^bb1:
|
||||
"std.terminator"() : () -> ()
|
||||
}, {}): (i1) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @std_if_illegal_block_argument(%arg0: i1) {
|
||||
// expected-error@+1 {{requires that child entry blocks have no arguments}}
|
||||
"std.if"(%arg0) ({
|
||||
^bb0(%0 : index):
|
||||
"std.terminator"() : () -> ()
|
||||
}, {}): (i1) -> ()
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue