forked from OSchip/llvm-project
[mlir][Standard] Allow select to use an i1 for vector and tensor values
It currently requires that the condition match the shape of the selected value, but this is only really useful for things like masks. This revision allows for the use of i1 to mean that all of the vector/tensor is selected. This also matches the behavior of LLVM select. A benefit of this change is that transformations that want to generate selects, like those on the CFG, don't have to special case vector/tensor. Previously the only way to generate a select from an i1 was to use a splat, but that doesn't support dynamically shaped/unranked tensors. Differential Revision: https://reviews.llvm.org/D78690
This commit is contained in:
parent
2fafe7ff59
commit
7f85adb54d
|
@ -1124,7 +1124,8 @@ public:
|
|||
|
||||
/// Compare this range with another.
|
||||
template <typename OtherT> bool operator==(const OtherT &other) const {
|
||||
return size() == std::distance(other.begin(), other.end()) &&
|
||||
return size() ==
|
||||
static_cast<size_t>(std::distance(other.begin(), other.end())) &&
|
||||
std::equal(begin(), end(), other.begin());
|
||||
}
|
||||
template <typename OtherT> bool operator!=(const OtherT &other) const {
|
||||
|
|
|
@ -1915,11 +1915,8 @@ def RsqrtOp : FloatUnaryOp<"rsqrt"> {
|
|||
// SelectOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
|
||||
AllTypesMatch<["true_value", "false_value", "result"]>,
|
||||
TypesMatchWith<"condition type matches i1 equivalent of result type",
|
||||
"result", "condition",
|
||||
"getI1SameShape($_self)">]> {
|
||||
def SelectOp : Std_Op<"select", [NoSideEffect,
|
||||
AllTypesMatch<["true_value", "false_value", "result"]>]> {
|
||||
let summary = "select operation";
|
||||
let description = [{
|
||||
The `select` operation chooses one value based on a binary condition
|
||||
|
@ -1930,7 +1927,8 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
|
|||
The operation applies to vectors and tensors elementwise given the _shape_
|
||||
of all operands is identical. The choice is made for each element
|
||||
individually based on the value at the same position as the element in the
|
||||
condition operand.
|
||||
condition operand. If an i1 is provided as the condition, the entire vector
|
||||
or tensor is chosen.
|
||||
|
||||
The `select` operation combined with [`cmpi`](#stdcmpi-cmpiop) can be used
|
||||
to implement `min` and `max` with signed or unsigned comparison semantics.
|
||||
|
@ -1944,9 +1942,11 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
|
|||
// Generic form of the same operation.
|
||||
%x = "std.select"(%cond, %true, %false) : (i1, i32, i32) -> i32
|
||||
|
||||
// Vector selection is element-wise
|
||||
%vx = "std.select"(%vcond, %vtrue, %vfalse)
|
||||
: (vector<42xi1>, vector<42xf32>, vector<42xf32>) -> vector<42xf32>
|
||||
// Element-wise vector selection.
|
||||
%vx = std.select %vcond, %vtrue, %vfalse : vector<42xi1>, vector<42xf32>
|
||||
|
||||
// Full vector selection.
|
||||
%vx = std.select %cond, %vtrue, %vfalse : vector<42xf32>
|
||||
```
|
||||
}];
|
||||
|
||||
|
@ -1954,7 +1954,6 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
|
|||
AnyType:$true_value,
|
||||
AnyType:$false_value);
|
||||
let results = (outs AnyType:$result);
|
||||
let verifier = ?;
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState &result, Value condition,"
|
||||
|
@ -1970,10 +1969,6 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
|
|||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$condition `,` $true_value `,` $false_value attr-dict `:` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -999,15 +999,6 @@ struct SimplifyCondBranchIdenticalSuccessors
|
|||
if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock())
|
||||
return failure();
|
||||
|
||||
// TODO: ATM Tensor/Vector SelectOp requires that the condition has the same
|
||||
// shape as the operands. We should relax that to allow an i1 to signify
|
||||
// that everything is selected.
|
||||
auto doesntSupportsScalarI1 = [](Type type) {
|
||||
return type.isa<TensorType>() || type.isa<VectorType>();
|
||||
};
|
||||
if (llvm::any_of(trueOperands.getTypes(), doesntSupportsScalarI1))
|
||||
return failure();
|
||||
|
||||
// Generate a select for any operands that differ between the two.
|
||||
SmallVector<Value, 8> mergedOperands;
|
||||
mergedOperands.reserve(trueOperands.size());
|
||||
|
@ -1925,6 +1916,59 @@ OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, SelectOp op) {
|
||||
p << "select " << op.getOperands();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : ";
|
||||
if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>())
|
||||
p << condType << ", ";
|
||||
p << op.getType();
|
||||
}
|
||||
|
||||
static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
|
||||
Type conditionType, resultType;
|
||||
SmallVector<OpAsmParser::OperandType, 3> operands;
|
||||
if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(resultType))
|
||||
return failure();
|
||||
|
||||
// Check for the explicit condition type if this is a masked tensor or vector.
|
||||
if (succeeded(parser.parseOptionalComma())) {
|
||||
conditionType = resultType;
|
||||
if (parser.parseType(resultType))
|
||||
return failure();
|
||||
} else {
|
||||
conditionType = parser.getBuilder().getI1Type();
|
||||
}
|
||||
|
||||
result.addTypes(resultType);
|
||||
return parser.resolveOperands(operands,
|
||||
{conditionType, resultType, resultType},
|
||||
parser.getNameLoc(), result.operands);
|
||||
}
|
||||
|
||||
static LogicalResult verify(SelectOp op) {
|
||||
Type conditionType = op.getCondition().getType();
|
||||
if (conditionType.isSignlessInteger(1))
|
||||
return success();
|
||||
|
||||
// If the result type is a vector or tensor, the type can be a mask with the
|
||||
// same elements.
|
||||
Type resultType = op.getType();
|
||||
if (!resultType.isa<TensorType>() && !resultType.isa<VectorType>())
|
||||
return op.emitOpError()
|
||||
<< "expected condition to be a signless i1, but got "
|
||||
<< conditionType;
|
||||
Type shapedConditionType = getI1SameShape(resultType);
|
||||
if (conditionType != shapedConditionType)
|
||||
return op.emitOpError()
|
||||
<< "expected condition type to have the same shape "
|
||||
"as the result type, expected "
|
||||
<< shapedConditionType << ", but got " << conditionType;
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SignExtendIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -69,39 +69,18 @@ func @cond_br_same_successor(%cond : i1, %a : i32) {
|
|||
|
||||
// CHECK-LABEL: func @cond_br_same_successor_insert_select(
|
||||
// CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
|
||||
func @cond_br_same_successor_insert_select(%cond : i1, %a : i32, %b : i32) -> i32 {
|
||||
// CHECK-SAME: %[[ARG2:.*]]: tensor<2xi32>, %[[ARG3:.*]]: tensor<2xi32>
|
||||
func @cond_br_same_successor_insert_select(
|
||||
%cond : i1, %a : i32, %b : i32, %c : tensor<2xi32>, %d : tensor<2xi32>
|
||||
) -> (i32, tensor<2xi32>) {
|
||||
// CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]]
|
||||
// CHECK: return %[[RES]]
|
||||
// CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG2]], %[[ARG3]]
|
||||
// CHECK: return %[[RES]], %[[RES2]]
|
||||
|
||||
cond_br %cond, ^bb1(%a : i32), ^bb1(%b : i32)
|
||||
cond_br %cond, ^bb1(%a, %c : i32, tensor<2xi32>), ^bb1(%b, %d : i32, tensor<2xi32>)
|
||||
|
||||
^bb1(%result : i32):
|
||||
return %result : i32
|
||||
}
|
||||
|
||||
/// Check that we don't generate a select if the type requires a splat.
|
||||
/// TODO: SelectOp should allow for matching a vector/tensor with i1.
|
||||
|
||||
// CHECK-LABEL: func @cond_br_same_successor_no_select_tensor(
|
||||
func @cond_br_same_successor_no_select_tensor(%cond : i1, %a : tensor<2xi32>,
|
||||
%b : tensor<2xi32>) -> tensor<2xi32>{
|
||||
// CHECK: cond_br
|
||||
|
||||
cond_br %cond, ^bb1(%a : tensor<2xi32>), ^bb1(%b : tensor<2xi32>)
|
||||
|
||||
^bb1(%result : tensor<2xi32>):
|
||||
return %result : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cond_br_same_successor_no_select_vector(
|
||||
func @cond_br_same_successor_no_select_vector(%cond : i1, %a : vector<2xi32>,
|
||||
%b : vector<2xi32>) -> vector<2xi32> {
|
||||
// CHECK: cond_br
|
||||
|
||||
cond_br %cond, ^bb1(%a : vector<2xi32>), ^bb1(%b : vector<2xi32>)
|
||||
|
||||
^bb1(%result : vector<2xi32>):
|
||||
return %result : vector<2xi32>
|
||||
^bb1(%result : i32, %result2 : tensor<2xi32>):
|
||||
return %result, %result2 : i32, tensor<2xi32>
|
||||
}
|
||||
|
||||
/// Test the compound folding of BranchOp and CondBranchOp.
|
||||
|
|
|
@ -141,17 +141,17 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
|
|||
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
|
||||
%21 = select %18, %idx, %idx : index
|
||||
|
||||
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi32>
|
||||
%22 = select %19, %tci32, %tci32 : tensor<42 x i32>
|
||||
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi1>, tensor<42xi32>
|
||||
%22 = select %19, %tci32, %tci32 : tensor<42 x i1>, tensor<42 x i32>
|
||||
|
||||
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xi32>
|
||||
%23 = select %20, %vci32, %vci32 : vector<42 x i32>
|
||||
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xi1>, vector<42xi32>
|
||||
%23 = select %20, %vci32, %vci32 : vector<42 x i1>, vector<42 x i32>
|
||||
|
||||
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
|
||||
%24 = "std.select"(%18, %idx, %idx) : (i1, index, index) -> index
|
||||
|
||||
// CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi32>
|
||||
%25 = "std.select"(%19, %tci32, %tci32) : (tensor<42 x i1>, tensor<42 x i32>, tensor<42 x i32>) -> tensor<42 x i32>
|
||||
%25 = std.select %18, %tci32, %tci32 : tensor<42 x i32>
|
||||
|
||||
// CHECK: %{{[0-9]+}} = divi_signed %arg2, %arg2 : i32
|
||||
%26 = divi_signed %i, %i : i32
|
||||
|
|
|
@ -281,18 +281,18 @@ func @func_with_ops(i1, i32, i64) {
|
|||
|
||||
// -----
|
||||
|
||||
func @func_with_ops(i1, vector<42xi32>, vector<42xi32>) {
|
||||
^bb0(%cond : i1, %t : vector<42xi32>, %f : vector<42xi32>):
|
||||
// expected-error@+1 {{requires the same shape for all operands and results}}
|
||||
%r = "std.select"(%cond, %t, %f) : (i1, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
|
||||
func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
|
||||
^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
|
||||
// expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'vector<42xi1>', but got 'vector<12xi1>'}}
|
||||
%r = "std.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
|
||||
^bb0(%cond : i1, %t : tensor<42xi32>, %f : tensor<?xi32>):
|
||||
// expected-error@+1 {{ op requires the same shape for all operands and results}}
|
||||
%r = "std.select"(%cond, %t, %f) : (i1, tensor<42xi32>, tensor<?xi32>) -> tensor<42xi32>
|
||||
func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) {
|
||||
^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
|
||||
// expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'tensor<42xi1>', but got 'tensor<12xi1>'}}
|
||||
%r = "std.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
Loading…
Reference in New Issue