forked from OSchip/llvm-project
StandardOps: introduce 'select'.
The semantics of 'select' is conventional: return the second operand if the first operand is true (1 : i1) and the third operand otherwise. It is applicable to vectors and tensors element-wise, similarly to LLVM instruction. This operation is necessary to implement min/max to lower 'for' loops with complex bounds to CFG functions and to support ternary operations in ML functions. It is preferred to first-class min/max because of its simplicity, e.g. it is not concered with signedness. PiperOrigin-RevId: 223160860
This commit is contained in:
parent
e7f43c8361
commit
a3fb6d0da3
|
@ -1900,11 +1900,45 @@ TODO: In the distant future, this will accept
|
|||
optional attributes for fast math, contraction, rounding mode, and other
|
||||
controls.
|
||||
|
||||
#### 'select' operation {#'select'-operation}
|
||||
|
||||
Syntax:
|
||||
|
||||
``` {.ebnf}
|
||||
operation ::= ssa-id `=` `select` ssa-use, ssa-use, ssa-use `:` type
|
||||
```
|
||||
|
||||
Examples:
|
||||
|
||||
```mlir {.mlir}
|
||||
// Short-hand notation of scalar selection.
|
||||
%x = select %cond, %true, %false : i32
|
||||
|
||||
// Long-hand notation of the same operation.
|
||||
%x = "select"(%cond, %true, %false) : (i1, i32, i32) -> i32
|
||||
|
||||
// Vector selection is element-wise
|
||||
%vx = "select"(%vcond, %vtrue, %vfalse)
|
||||
: (vector<42xi1>, vector<42xf32>, vector<42xf32>) -> vector<42xf32>
|
||||
```
|
||||
|
||||
The `select` operation chooses one value based on a binary condition supplied as
|
||||
its first operand. If the value of the first operand is `1`, the second operand
|
||||
is chosen, otherwise the third operand is chosen. The second and the third
|
||||
operand must have the same type.
|
||||
|
||||
The operation applies to vectors and tensors elementwise given the _shape_ of
|
||||
all operands is identical. The choice is made for each element individually
|
||||
based on the value at the same position as the element in the condition operand.
|
||||
|
||||
The `select` operation combined with [`cmpi`](#'cmpi'-operation) can be used to
|
||||
implement `min` and `max` with signed or unsigned comparison semantics.
|
||||
|
||||
#### 'tensor_cast' operation {#'tensor_cast'-operation}
|
||||
|
||||
Syntax:
|
||||
|
||||
```mlir {.mlir}
|
||||
``` {.ebnf}
|
||||
operation ::= ssa-id `=` `tensor_cast` ssa-use `:` type `to` type
|
||||
```
|
||||
|
||||
|
|
|
@ -294,12 +294,39 @@ readability by humans, short-hand notation accepts string literals that are
|
|||
mapped to the underlying integer values: `cmpi "eq", %lhs, %rhs` better implies
|
||||
integer equality comparison than `cmpi 0, %lhs, %rhs` where it is unclear what
|
||||
gets compared to what else. This syntactic sugar is possible thanks to parser
|
||||
logic redifinitions for short-hand notation of non-builtin operations.
|
||||
logic redefinitions for short-hand notation of non-builtin operations.
|
||||
Supporting it in the full notation would have required changing how the main
|
||||
parsing algorithm works and may have unexpected repercussions. While it had been
|
||||
possible to store the predicate as string attribute, it would have rendered
|
||||
impossible to implement switching logic based on the comparison kind and made
|
||||
attribute validity checks (one out of ten possibile kinds) more complex.
|
||||
attribute validity checks (one out of ten possible kinds) more complex.
|
||||
|
||||
### 'select' operation to implement min/max {#select-operation}
|
||||
|
||||
Although `min` and `max` operations are likely to occur as a result of
|
||||
transforming affine loops in ML functions, we did not make them first-class
|
||||
operations. Instead, we provide the `select` operation that can be combined with
|
||||
`cmpi` to implement the minimum and maximum computation. Although they now
|
||||
require two operations, they are likely to be emitted automatically during the
|
||||
transformation inside MLIR. On the other hand, there are multiple benefits of
|
||||
introducing `select`: standalone min/max would concern themselves with the
|
||||
signedness of the comparison, already taken into account by `cmpi`; `select` can
|
||||
support floats transparently if used after a float-comparison operation; the
|
||||
lower-level targets provide `select`-like instructions making the translation
|
||||
trivial.
|
||||
|
||||
This operation could have been implemented with additional control flow: `%r =
|
||||
select %cond, %t, %f` is equivalent to
|
||||
|
||||
```mlir
|
||||
bb0:
|
||||
br_cond %cond, bb1(%t), bb1(%f)
|
||||
bb1(%r):
|
||||
```
|
||||
|
||||
However, this control flow granularity is not available in the ML functions
|
||||
where min/max, and thus `select`, are likely to appear. In addition, simpler
|
||||
control flow may be beneficial for optimization in general.
|
||||
|
||||
### Quantized integer operations {#quantized-integer-operations}
|
||||
|
||||
|
|
|
@ -26,7 +26,6 @@
|
|||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "third_party/llvm/llvm/projects/google-mlir/include/mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace mlir {
|
||||
class Builder;
|
||||
|
@ -638,6 +637,32 @@ private:
|
|||
explicit MulIOp(const Operation *state) : BinaryOp(state) {}
|
||||
};
|
||||
|
||||
class SelectOp : public Op<SelectOp, OpTrait::NOperands<3>::Impl,
|
||||
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
static StringRef getOperationName() { return "select"; }
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
SSAValue *condition, SSAValue *trueValue,
|
||||
SSAValue *falseValue);
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
bool verify() const;
|
||||
|
||||
SSAValue *getCondition() { return getOperand(0); }
|
||||
const SSAValue *getCondition() const { return getOperand(0); }
|
||||
SSAValue *getTrueValue() { return getOperand(1); }
|
||||
const SSAValue *getTrueValue() const { return getOperand(1); }
|
||||
SSAValue *getFalseValue() { return getOperand(2); }
|
||||
const SSAValue *getFalseValue() const { return getOperand(2); }
|
||||
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const;
|
||||
|
||||
private:
|
||||
friend class Operation;
|
||||
explicit SelectOp(const Operation *state) : Op(state) {}
|
||||
};
|
||||
|
||||
/// The "store" op writes an element to a memref specified by an index list.
|
||||
/// The arity of indices is the rank of the memref (i.e. if the memref being
|
||||
/// stored to is of rank 3, then 3 indices are required for the store following
|
||||
|
|
|
@ -39,8 +39,8 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
|
|||
: Dialect(/*opPrefix=*/"", context) {
|
||||
addOperations<AddFOp, AddIOp, AllocOp, CallOp, CallIndirectOp, CmpIOp,
|
||||
DeallocOp, DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp,
|
||||
LoadOp, MemRefCastOp, MulFOp, MulIOp, StoreOp, SubFOp, SubIOp,
|
||||
TensorCastOp>();
|
||||
LoadOp, MemRefCastOp, MulFOp, MulIOp, SelectOp, StoreOp, SubFOp,
|
||||
SubIOp, TensorCastOp>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1085,6 +1085,75 @@ void MulIOp::getCanonicalizationPatterns(OwningPatternList &results,
|
|||
results.push_back(std::make_unique<SimplifyMulX1>(context));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SelectOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
void SelectOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *condition, SSAValue *trueValue,
|
||||
SSAValue *falseValue) {
|
||||
result->addOperands({condition, trueValue, falseValue});
|
||||
result->addTypes(trueValue->getType());
|
||||
}
|
||||
|
||||
bool SelectOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 3> ops;
|
||||
SmallVector<NamedAttribute, 4> attrs;
|
||||
Type type;
|
||||
|
||||
if (parser->parseOperandList(ops, 3) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(type))
|
||||
return true;
|
||||
|
||||
auto i1Type = getI1SameShape(&parser->getBuilder(), type);
|
||||
SmallVector<Type, 3> types = {i1Type, type, type};
|
||||
return parser->resolveOperands(ops, types, parser->getNameLoc(),
|
||||
result->operands) ||
|
||||
parser->addTypeToList(type, result->types);
|
||||
}
|
||||
|
||||
void SelectOp::print(OpAsmPrinter *p) const {
|
||||
*p << getOperationName() << ' ';
|
||||
p->printOperands(getOperation()->getOperands());
|
||||
*p << " : " << getTrueValue()->getType();
|
||||
p->printOptionalAttrDict(getAttrs());
|
||||
}
|
||||
|
||||
bool SelectOp::verify() const {
|
||||
auto conditionType = getCondition()->getType();
|
||||
auto trueType = getTrueValue()->getType();
|
||||
auto falseType = getFalseValue()->getType();
|
||||
|
||||
if (trueType != falseType)
|
||||
return emitOpError(
|
||||
"requires 'true' and 'false' arguments to be of the same type");
|
||||
|
||||
if (checkI1SameShape(trueType, conditionType))
|
||||
return emitOpError("requires the condition to have the same shape as "
|
||||
"arguments with elemental type i1");
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
Attribute SelectOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
assert(operands.size() == 3 && "select takes three operands");
|
||||
|
||||
// select true, %0, %1 => %0
|
||||
// select false, %0, %1 => %1
|
||||
auto cond = operands[0].dyn_cast_or_null<IntegerAttr>();
|
||||
if (!cond)
|
||||
return {};
|
||||
|
||||
if (cond.getValue().isNullValue()) {
|
||||
return operands[2];
|
||||
} else if (cond.getValue().isOneValue()) {
|
||||
return operands[1];
|
||||
}
|
||||
|
||||
llvm_unreachable("first argument of select must be i1");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StoreOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -124,6 +124,21 @@ bb42(%t: tensor<4x4x?xf32>, %f: f32, %i: i32, %idx : index):
|
|||
// CHECK: %{{[0-9]+}} = cmpi "eq", %cst_5, %cst_5 : vector<42xindex>
|
||||
%20 = cmpi "eq", %cidx, %cidx : vector<42 x index>
|
||||
|
||||
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
|
||||
%21 = select %18, %idx, %idx : index
|
||||
|
||||
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xindex>
|
||||
%22 = select %19, %tidx, %tidx : tensor<42 x index>
|
||||
|
||||
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xindex>
|
||||
%23 = select %20, %cidx, %cidx : vector<42 x index>
|
||||
|
||||
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
|
||||
%24 = "select"(%18, %idx, %idx) : (i1, index, index) -> index
|
||||
|
||||
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xindex>
|
||||
%25 = "select"(%19, %tidx, %tidx) : (tensor<42 x i1>, tensor<42 x index>, tensor<42 x index>) -> tensor<42 x index>
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -230,3 +230,51 @@ bb0:
|
|||
%r = "cmpi"(%c, %c) {predicate: 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<42 x i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
cfgfunc @cfgfunc_with_ops(i32, i32, i32) {
|
||||
bb0(%cond : i32, %t : i32, %f : i32):
|
||||
// expected-error@+2 {{different type than prior uses}}
|
||||
// expected-error@-2 {{prior use here}}
|
||||
%r = select %cond, %t, %f : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
cfgfunc @cfgfunc_with_ops(i32, i32, i32) {
|
||||
bb0(%cond : i32, %t : i32, %f : i32):
|
||||
// expected-error@+1 {{elemental type i1}}
|
||||
%r = "select"(%cond, %t, %f) : (i32, i32, i32) -> i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
cfgfunc @cfgfunc_with_ops(i1, i32, i64) {
|
||||
bb0(%cond : i1, %t : i32, %f : i64):
|
||||
// expected-error@+1 {{'true' and 'false' arguments to be of the same type}}
|
||||
%r = "select"(%cond, %t, %f) : (i1, i32, i64) -> i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
cfgfunc @cfgfunc_with_ops(i1, vector<42xi32>, vector<42xi32>) {
|
||||
bb0(%cond : i1, %t : vector<42xi32>, %f : vector<42xi32>):
|
||||
// expected-error@+1 {{requires the condition to have the same shape as arguments}}
|
||||
%r = "select"(%cond, %t, %f) : (i1, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
cfgfunc @cfgfunc_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
|
||||
bb0(%cond : i1, %t : tensor<42xi32>, %f : tensor<?xi32>):
|
||||
// expected-error@+1 {{'true' and 'false' arguments to be of the same type}}
|
||||
%r = "select"(%cond, %t, %f) : (i1, tensor<42xi32>, tensor<?xi32>) -> tensor<42xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
cfgfunc @cfgfunc_with_ops(tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) {
|
||||
bb0(%cond : tensor<?xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
|
||||
// expected-error@+1 {{requires the condition to have the same shape as arguments}}
|
||||
%r = "select"(%cond, %t, %f) : (tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue