Add a standard if op

This CL adds an "std.if" op to represent an if-then-else construct whose condition is an arbitrary value of type i1.
This is necessary to lower all the existing examples from affine and linalg to std.for + std.if.

This CL introduces the op and adds the relevant positive and negative unit test. Lowering will be done in a separate followup CL.

PiperOrigin-RevId: 256649138
This commit is contained in:
Nicolas Vasilache 2019-07-05 03:34:49 -07:00 committed by A. Unique TensorFlower
parent 471da08e4e
commit 991040478b
5 changed files with 194 additions and 11 deletions

View File

@ -24,6 +24,7 @@
#define MLIR_STANDARDOPS_OPS_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
@ -356,6 +357,12 @@ ParseResult parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<Value *, 4> &operands,
unsigned &numDims);
// Insert `std.terminator` at the end of the only region's only block if it does
// not have a terminator already. If a new `std.terminator` is inserted, the
// location is specified by `loc`. If the region is empty, insert a new block
// first.
void ensureStdTerminator(Region &region, Builder &builder, Location loc);
/// The "std.for" operation represents a loop nest taking 3 SSA value as
/// operands that represent the lower bound, upper bound and step respectively.
/// The operation defines an SSA value for its induction variable. It has one

View File

@ -609,6 +609,46 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
let hasFolder = 1;
}
// TODO(ntv): Default generated builder creates IR that does not verify. Atm it
// is the responsibility of each caller to call ensureStdTerminator on the
// then and else regions.
def IfOp : Std_Op<"if"> {
let summary = "if-then-else operation";
let description = [{
The "std.if" operation represents an if-then-else construct for
conditionally executing two regions of code. The operand to an if operation
is a boolean value. The operation produces no results. For example:
std.if %b {
...
} else {
...
}
The 'else' block is optional, and may be omitted. For
example:
std.if %b {
...
}
}];
let arguments = (ins I1:$condition);
let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
let extraClassDeclaration = [{
OpBuilder getThenBodyBuilder() {
assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
Block &body = thenRegion().front();
return OpBuilder(&body, std::prev(body.end()));
}
OpBuilder getElseBodyBuilder() {
assert(!elseRegion().empty() && "Unexpected empty 'else' region.");
Block &body = elseRegion().front();
return OpBuilder(&body, std::prev(body.end()));
}
}];
}
def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
let summary = "cast between index and integer types";
let description = [{

View File

@ -1638,8 +1638,9 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
////////////////////////////////////////////////////////////////////////////////
// StdForOp.
////////////////////////////////////////////////////////////////////////////////
// Check that if a "block" has a terminator, it is an `TerminatorOp`.
static LogicalResult checkHasTerminator(OpState &op, Block &block) {
static LogicalResult checkHasStdTerminator(OpState &op, Block &block) {
if (block.empty() || isa<StdTerminatorOp>(block.back()))
return success();
@ -1650,11 +1651,7 @@ static LogicalResult checkHasTerminator(OpState &op, Block &block) {
<< StdTerminatorOp::getOperationName() << "'";
}
// Insert `cf.terminator` at the end of the StdForOp only region's only block
// if it does not have a terminator already. If a new `cf.terminator` is
// inserted, the location is specified by `loc`. If the region is empty,
// insert a new block first.
static void ensureTerminator(Region &region, Builder &builder, Location loc) {
void mlir::ensureStdTerminator(Region &region, Builder &builder, Location loc) {
impl::ensureRegionTerminator<StdTerminatorOp>(region, builder, loc);
}
@ -1665,7 +1662,7 @@ void StdForOp::build(Builder *builder, OperationState *result, Value *lb,
Block *body = new Block();
body->addArgument(IndexType::get(builder->getContext()));
bodyRegion->push_back(body);
ensureTerminator(*bodyRegion, *builder, result->location);
ensureStdTerminator(*bodyRegion, *builder, result->location);
}
LogicalResult StdForOp::verify() {
@ -1694,7 +1691,7 @@ LogicalResult StdForOp::verify() {
!body->getArgument(0)->getType().isIndex())
return emitOpError("expected body to have a single index argument for "
"the induction variable");
if (failed(checkHasTerminator(*this, *body)))
if (failed(checkHasStdTerminator(*this, *body)))
return failure();
return success();
}
@ -1731,7 +1728,7 @@ ParseResult StdForOp::parse(OpAsmParser *parser, OperationState *result) {
if (parser->parseRegion(*body, inductionVariable, indexType))
return failure();
ensureTerminator(*body, builder, result->location);
ensureStdTerminator(*body, builder, result->location);
// Parse the optional attribute list.
if (parser->parseOptionalAttributeDict(result->attributes))
@ -1754,6 +1751,81 @@ StdForOp getStdForInductionVarOwner(Value *val) {
return dyn_cast_or_null<StdForOp>(containingInst);
}
//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(IfOp op) {
// Verify that the entry block of each child region does not have arguments.
for (auto &region : op.getOperation()->getRegions()) {
if (region.empty())
continue;
// TODO(riverriddle) We currently do not allow multiple blocks in child
// regions.
if (std::next(region.begin()) != region.end())
return op.emitOpError("expected one block per 'then' or 'else' regions");
if (failed(checkHasStdTerminator(op, region.front())))
return failure();
for (auto &b : region)
if (b.getNumArguments() != 0)
return op.emitOpError(
"requires that child entry blocks have no arguments");
}
return success();
}
static ParseResult parseIfOp(OpAsmParser *parser, OperationState *result) {
// Create the regions for 'then'.
result->regions.reserve(2);
Region *thenRegion = result->addRegion();
Region *elseRegion = result->addRegion();
auto &builder = parser->getBuilder();
OpAsmParser::OperandType cond;
Type i1Type = builder.getIntegerType(1);
if (parser->parseOperand(cond) ||
parser->resolveOperand(cond, i1Type, result->operands))
return failure();
// Parse the 'then' region.
if (parser->parseRegion(*thenRegion, {}, {}))
return failure();
ensureStdTerminator(*thenRegion, parser->getBuilder(), result->location);
// If we find an 'else' keyword then parse the 'else' region.
if (!parser->parseOptionalKeyword("else")) {
if (parser->parseRegion(*elseRegion, {}, {}))
return failure();
ensureStdTerminator(*elseRegion, parser->getBuilder(), result->location);
}
// Parse the optional attribute list.
if (parser->parseOptionalAttributeDict(result->attributes))
return failure();
return success();
}
static void print(OpAsmPrinter *p, IfOp op) {
*p << IfOp::getOperationName() << " " << *op.condition();
p->printRegion(op.thenRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
// Print the 'else' regions if it exists and has a block.
auto &elseRegion = op.elseRegion();
if (!elseRegion.empty()) {
*p << " else";
p->printRegion(elseRegion,
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
}
p->printOptionalAttrDict(op.getAttrs());
}
//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//

View File

@ -437,7 +437,7 @@ func @std_for(%arg0 : index, %arg1 : index, %arg2 : index) {
}
return
}
// CHECK-LABEL: func @std_for(%arg0: index, %arg1: index, %arg2: index) {
// CHECK-LABEL: func @std_for(
// CHECK-NEXT: std.for %i0 = %arg0 to %arg1 step %arg2 {
// CHECK-NEXT: std.for %i1 = %arg0 to %arg1 step %arg2 {
// CHECK-NEXT: %0 = cmpi "slt", %i0, %i1 : index
@ -445,3 +445,27 @@ func @std_for(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK-NEXT: %2 = cmpi "sge", %i0, %i1 : index
// CHECK-NEXT: %3 = select %2, %i0, %i1 : index
// CHECK-NEXT: std.for %i2 = %1 to %3 step %i1 {
func @std_if(%arg0: i1, %arg1: f32) {
std.if %arg0 {
%0 = addf %arg1, %arg1 : f32
}
return
}
// CHECK-LABEL: func @std_if(
// CHECK-NEXT: std.if %arg0 {
// CHECK-NEXT: %{{.*}} = addf %arg1, %arg1 : f32
func @std_if_else(%arg0: i1, %arg1: f32) {
std.if %arg0 {
%0 = addf %arg1, %arg1 : f32
} else {
%1 = addf %arg1, %arg1 : f32
}
return
}
// CHECK-LABEL: func @std_if_else(
// CHECK-NEXT: std.if %arg0 {
// CHECK-NEXT: %{{.*}} = addf %arg1, %arg1 : f32
// CHECK-NEXT: } else {
// CHECK-NEXT: %{{.*}} = addf %arg1, %arg1 : f32

View File

@ -770,4 +770,44 @@ func @std_for_single_index_argument(%arg0: index) {
}
) : (index, index, index) -> ()
return
}
}
// -----
func @std_if_not_i1(%arg0: index) {
// expected-error@+1 {{operand #0 must be 1-bit integer}}
"std.if"(%arg0) : (index) -> ()
return
}
// -----
func @std_if_more_than_2_regions(%arg0: i1) {
// expected-error@+1 {{op has incorrect number of regions: expected 2}}
"std.if"(%arg0) ({}, {}, {}): (i1) -> ()
return
}
// -----
func @std_if_not_one_block_per_region(%arg0: i1) {
// expected-error@+1 {{region #0 ('thenRegion') failed to verify constraint: region with 1 blocks}}
"std.if"(%arg0) ({
^bb0:
"std.terminator"() : () -> ()
^bb1:
"std.terminator"() : () -> ()
}, {}): (i1) -> ()
return
}
// -----
func @std_if_illegal_block_argument(%arg0: i1) {
// expected-error@+1 {{requires that child entry blocks have no arguments}}
"std.if"(%arg0) ({
^bb0(%0 : index):
"std.terminator"() : () -> ()
}, {}): (i1) -> ()
return
}