StandardOps: introduce 'select'.

The semantics of 'select' is conventional: return the second operand if the
first operand is true (1 : i1) and the third operand otherwise.  It is
applicable to vectors and tensors element-wise, similarly to the LLVM
instruction.  This operation is necessary to implement the min/max needed to
lower 'for' loops with complex bounds to CFG functions, and to support ternary
operations in ML functions.  It is preferred to first-class min/max because of
its simplicity, e.g. it is not concerned with signedness.

PiperOrigin-RevId: 223160860
This commit is contained in:
Alex Zinenko 2018-11-28 07:08:55 -08:00 committed by jpienaar
parent e7f43c8361
commit a3fb6d0da3
6 changed files with 224 additions and 6 deletions


@@ -1900,11 +1900,45 @@ TODO: In the distant future, this will accept
optional attributes for fast math, contraction, rounding mode, and other
controls.
#### 'select' operation {#'select'-operation}
Syntax:
``` {.ebnf}
operation ::= ssa-id `=` `select` ssa-use, ssa-use, ssa-use `:` type
```
Examples:
```mlir {.mlir}
// Short-hand notation of scalar selection.
%x = select %cond, %true, %false : i32
// Long-hand notation of the same operation.
%x = "select"(%cond, %true, %false) : (i1, i32, i32) -> i32
// Vector selection is element-wise
%vx = "select"(%vcond, %vtrue, %vfalse)
: (vector<42xi1>, vector<42xf32>, vector<42xf32>) -> vector<42xf32>
```
The `select` operation chooses one value based on a binary condition supplied as
its first operand. If the value of the first operand is `1`, the second operand
is chosen; otherwise, the third operand is chosen. The second and the third
operands must have the same type.
The operation applies to vectors and tensors element-wise, provided the _shape_
of all operands is identical. The choice is made for each element individually
based on the value at the same position as the element in the condition operand.
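For instance (operand names and values hypothetical), an element-wise selection over vectors picks from each operand position by position:

```mlir {.mlir}
// Hypothetical: if %vcond holds [1, 0, 1], the result takes elements 0 and 2
// from %vtrue and element 1 from %vfalse.
%v = select %vcond, %vtrue, %vfalse : vector<3xi32>
```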
The `select` operation combined with [`cmpi`](#'cmpi'-operation) can be used to
implement `min` and `max` with signed or unsigned comparison semantics.
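As a sketch (operand names hypothetical), a signed minimum can be expressed by feeding the result of a signed less-than comparison into `select`; an unsigned maximum only differs in the `cmpi` predicate:

```mlir {.mlir}
// Hypothetical %a, %b : i32. Signed minimum via cmpi + select.
%lt = cmpi "slt", %a, %b : i32
%min = select %lt, %a, %b : i32
// Unsigned maximum: same pattern with the "ugt" predicate.
%gt = cmpi "ugt", %a, %b : i32
%max = select %gt, %a, %b : i32
```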
#### 'tensor_cast' operation {#'tensor_cast'-operation}
Syntax:
``` {.ebnf}
operation ::= ssa-id `=` `tensor_cast` ssa-use `:` type `to` type
```


@@ -294,12 +294,39 @@ readability by humans, short-hand notation accepts string literals that are
mapped to the underlying integer values: `cmpi "eq", %lhs, %rhs` better implies
integer equality comparison than `cmpi 0, %lhs, %rhs` where it is unclear what
gets compared to what else. This syntactic sugar is possible thanks to parser
logic redefinitions for short-hand notation of non-builtin operations.
Supporting it in the full notation would have required changing how the main
parsing algorithm works and may have unexpected repercussions. While it would
have been possible to store the predicate as a string attribute, that would have
made it impossible to implement switching logic based on the comparison kind and
would have made attribute validity checks (one out of ten possible kinds) more
complex.
### 'select' operation to implement min/max {#select-operation}
Although `min` and `max` operations are likely to occur as a result of
transforming affine loops in ML functions, we did not make them first-class
operations. Instead, we provide the `select` operation that can be combined with
`cmpi` to implement the minimum and maximum computation. Although they now
require two operations, they are likely to be emitted automatically during the
transformation inside MLIR. On the other hand, there are multiple benefits of
introducing `select`: standalone min/max would concern themselves with the
signedness of the comparison, already taken into account by `cmpi`; `select` can
support floats transparently if used after a float-comparison operation; the
lower-level targets provide `select`-like instructions making the translation
trivial.
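For example, assuming a float-comparison operation analogous to `cmpi` (hypothetical here, since only `cmpi` is defined in this change), a float minimum would reuse the exact same `select` pattern:

```mlir
// Hypothetical "cmpf" op with an ordered-less-than predicate; %a, %b : f32.
%lt = cmpf "olt", %a, %b : f32
%min = select %lt, %a, %b : f32
```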
This operation could have been implemented with additional control flow: `%r =
select %cond, %t, %f` is equivalent to
```mlir
bb0:
br_cond %cond, bb1(%t), bb1(%f)
bb1(%r):
```
However, this control flow granularity is not available in the ML functions
where min/max, and thus `select`, are likely to appear. In addition, simpler
control flow may be beneficial for optimization in general.
### Quantized integer operations {#quantized-integer-operations}


@@ -26,7 +26,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
namespace mlir {
class Builder;
@@ -638,6 +637,32 @@ private:
explicit MulIOp(const Operation *state) : BinaryOp(state) {}
};
class SelectOp : public Op<SelectOp, OpTrait::NOperands<3>::Impl,
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public:
static StringRef getOperationName() { return "select"; }
static void build(Builder *builder, OperationState *result,
SSAValue *condition, SSAValue *trueValue,
SSAValue *falseValue);
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
SSAValue *getCondition() { return getOperand(0); }
const SSAValue *getCondition() const { return getOperand(0); }
SSAValue *getTrueValue() { return getOperand(1); }
const SSAValue *getTrueValue() const { return getOperand(1); }
SSAValue *getFalseValue() { return getOperand(2); }
const SSAValue *getFalseValue() const { return getOperand(2); }
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
private:
friend class Operation;
explicit SelectOp(const Operation *state) : Op(state) {}
};
/// The "store" op writes an element to a memref specified by an index list.
/// The arity of indices is the rank of the memref (i.e. if the memref being
/// stored to is of rank 3, then 3 indices are required for the store following


@@ -39,8 +39,8 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
: Dialect(/*opPrefix=*/"", context) {
addOperations<AddFOp, AddIOp, AllocOp, CallOp, CallIndirectOp, CmpIOp,
DeallocOp, DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp,
LoadOp, MemRefCastOp, MulFOp, MulIOp, SelectOp, StoreOp, SubFOp,
SubIOp, TensorCastOp>();
}
//===----------------------------------------------------------------------===//
@@ -1085,6 +1085,75 @@ void MulIOp::getCanonicalizationPatterns(OwningPatternList &results,
results.push_back(std::make_unique<SimplifyMulX1>(context));
}
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
void SelectOp::build(Builder *builder, OperationState *result,
SSAValue *condition, SSAValue *trueValue,
SSAValue *falseValue) {
result->addOperands({condition, trueValue, falseValue});
result->addTypes(trueValue->getType());
}
bool SelectOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<NamedAttribute, 4> attrs;
Type type;
if (parser->parseOperandList(ops, 3) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type))
return true;
auto i1Type = getI1SameShape(&parser->getBuilder(), type);
SmallVector<Type, 3> types = {i1Type, type, type};
return parser->resolveOperands(ops, types, parser->getNameLoc(),
result->operands) ||
parser->addTypeToList(type, result->types);
}
void SelectOp::print(OpAsmPrinter *p) const {
*p << getOperationName() << ' ';
p->printOperands(getOperation()->getOperands());
*p << " : " << getTrueValue()->getType();
p->printOptionalAttrDict(getAttrs());
}
bool SelectOp::verify() const {
auto conditionType = getCondition()->getType();
auto trueType = getTrueValue()->getType();
auto falseType = getFalseValue()->getType();
if (trueType != falseType)
return emitOpError(
"requires 'true' and 'false' arguments to be of the same type");
if (checkI1SameShape(trueType, conditionType))
return emitOpError("requires the condition to have the same shape as "
"arguments with elemental type i1");
return false;
}
Attribute SelectOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
assert(operands.size() == 3 && "select takes three operands");
// select true, %0, %1 => %0
// select false, %0, %1 => %1
auto cond = operands[0].dyn_cast_or_null<IntegerAttr>();
if (!cond)
return {};
if (cond.getValue().isNullValue()) {
return operands[2];
} else if (cond.getValue().isOneValue()) {
return operands[1];
}
llvm_unreachable("first argument of select must be i1");
}
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//


@@ -124,6 +124,21 @@ bb42(%t: tensor<4x4x?xf32>, %f: f32, %i: i32, %idx : index):
// CHECK: %{{[0-9]+}} = cmpi "eq", %cst_5, %cst_5 : vector<42xindex>
%20 = cmpi "eq", %cidx, %cidx : vector<42 x index>
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
%21 = select %18, %idx, %idx : index
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xindex>
%22 = select %19, %tidx, %tidx : tensor<42 x index>
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xindex>
%23 = select %20, %cidx, %cidx : vector<42 x index>
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
%24 = "select"(%18, %idx, %idx) : (i1, index, index) -> index
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xindex>
%25 = "select"(%19, %tidx, %tidx) : (tensor<42 x i1>, tensor<42 x index>, tensor<42 x index>) -> tensor<42 x index>
return
}


@@ -230,3 +230,51 @@ bb0:
%r = "cmpi"(%c, %c) {predicate: 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<42 x i32>
}
// -----
cfgfunc @cfgfunc_with_ops(i32, i32, i32) {
bb0(%cond : i32, %t : i32, %f : i32):
// expected-error@+2 {{different type than prior uses}}
// expected-error@-2 {{prior use here}}
%r = select %cond, %t, %f : i32
}
// -----
cfgfunc @cfgfunc_with_ops(i32, i32, i32) {
bb0(%cond : i32, %t : i32, %f : i32):
// expected-error@+1 {{elemental type i1}}
%r = "select"(%cond, %t, %f) : (i32, i32, i32) -> i32
}
// -----
cfgfunc @cfgfunc_with_ops(i1, i32, i64) {
bb0(%cond : i1, %t : i32, %f : i64):
// expected-error@+1 {{'true' and 'false' arguments to be of the same type}}
%r = "select"(%cond, %t, %f) : (i1, i32, i64) -> i32
}
// -----
cfgfunc @cfgfunc_with_ops(i1, vector<42xi32>, vector<42xi32>) {
bb0(%cond : i1, %t : vector<42xi32>, %f : vector<42xi32>):
// expected-error@+1 {{requires the condition to have the same shape as arguments}}
%r = "select"(%cond, %t, %f) : (i1, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
}
// -----
cfgfunc @cfgfunc_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
bb0(%cond : i1, %t : tensor<42xi32>, %f : tensor<?xi32>):
// expected-error@+1 {{'true' and 'false' arguments to be of the same type}}
%r = "select"(%cond, %t, %f) : (i1, tensor<42xi32>, tensor<?xi32>) -> tensor<42xi32>
}
// -----
cfgfunc @cfgfunc_with_ops(tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) {
bb0(%cond : tensor<?xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
// expected-error@+1 {{requires the condition to have the same shape as arguments}}
%r = "select"(%cond, %t, %f) : (tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
}