forked from OSchip/llvm-project
[mlir][Vector] Support 0-D vectors in `CmpIOp`
Following the example of `VectorOfAnyRankOf`, I've done a few changes in the `.td` files to help with adding the support for the 0-D case gradually. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D115220
This commit is contained in:
parent
8e833d081b
commit
a0c930d312
|
@ -116,6 +116,13 @@ class Arith_CompareOp<string mnemonic, list<OpTrait> traits = []> :
|
||||||
let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
|
let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Just like `Arith_CompareOp` but also admits 0-D vectors. Introduced
|
||||||
|
// temporarily to allow gradual transition to 0-D vectors.
|
||||||
|
class Arith_CompareOpOfAnyRank<string mnemonic, list<OpTrait> traits = []> :
|
||||||
|
Arith_CompareOp<mnemonic, traits> {
|
||||||
|
let results = (outs BoolLikeOfAnyRank:$result);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ConstantOp
|
// ConstantOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -284,7 +291,7 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi"> {
|
||||||
def Arith_CeilDivUIOp : Arith_IntBinaryOp<"ceildivui"> {
|
def Arith_CeilDivUIOp : Arith_IntBinaryOp<"ceildivui"> {
|
||||||
let summary = "unsigned ceil integer division operation";
|
let summary = "unsigned ceil integer division operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
Unsigned integer division. Rounds towards positive infinity. Treats the
|
Unsigned integer division. Rounds towards positive infinity. Treats the
|
||||||
leading bit as the most significant, i.e. for `i16` given two's complement
|
leading bit as the most significant, i.e. for `i16` given two's complement
|
||||||
representation, `6 / -2 = 6 / (2^16 - 2) = 1`.
|
representation, `6 / -2 = 6 / (2^16 - 2) = 1`.
|
||||||
|
|
||||||
|
@ -990,7 +997,7 @@ def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
|
||||||
// CmpIOp
|
// CmpIOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Arith_CmpIOp : Arith_CompareOp<"cmpi"> {
|
def Arith_CmpIOp : Arith_CompareOpOfAnyRank<"cmpi"> {
|
||||||
let summary = "integer comparison operation";
|
let summary = "integer comparison operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
The `cmpi` operation is a generic comparison for integer-like types. Its two
|
The `cmpi` operation is a generic comparison for integer-like types. Its two
|
||||||
|
@ -1057,8 +1064,8 @@ def Arith_CmpIOp : Arith_CompareOp<"cmpi"> {
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
|
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
|
||||||
SignlessIntegerLike:$lhs,
|
SignlessIntegerLikeOfAnyRank:$lhs,
|
||||||
SignlessIntegerLike:$rhs);
|
SignlessIntegerLikeOfAnyRank:$rhs);
|
||||||
|
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{
|
OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{
|
||||||
|
|
|
@ -213,6 +213,7 @@ def IsVectorTypePred : And<[CPred<"$_self.isa<::mlir::VectorType>()">,
|
||||||
CPred<"$_self.cast<::mlir::VectorType>().getRank() > 0">]>;
|
CPred<"$_self.cast<::mlir::VectorType>().getRank() > 0">]>;
|
||||||
|
|
||||||
// Temporary vector type clone that allows gradual transition to 0-D vectors.
|
// Temporary vector type clone that allows gradual transition to 0-D vectors.
|
||||||
|
// TODO: Remove this when all ops support 0-D vectors.
|
||||||
def IsVectorOfAnyRankTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
|
def IsVectorOfAnyRankTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
|
||||||
|
|
||||||
// Whether a type is a TensorType.
|
// Whether a type is a TensorType.
|
||||||
|
@ -603,7 +604,9 @@ class HasAnyRankOfPred<list<int> ranks> : And<[
|
||||||
class VectorOf<list<Type> allowedTypes> :
|
class VectorOf<list<Type> allowedTypes> :
|
||||||
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
|
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
|
||||||
"::mlir::VectorType">;
|
"::mlir::VectorType">;
|
||||||
|
|
||||||
// Temporary vector type clone that allows gradual transition to 0-D vectors.
|
// Temporary vector type clone that allows gradual transition to 0-D vectors.
|
||||||
|
// TODO: Remove this when all ops support 0-D vectors.
|
||||||
class VectorOfAnyRankOf<list<Type> allowedTypes> :
|
class VectorOfAnyRankOf<list<Type> allowedTypes> :
|
||||||
ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
|
ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
|
||||||
"::mlir::VectorType">;
|
"::mlir::VectorType">;
|
||||||
|
@ -835,6 +838,14 @@ def BoolLike : TypeConstraint<Or<[I1.predicate, VectorOf<[I1]>.predicate,
|
||||||
TensorOf<[I1]>.predicate]>,
|
TensorOf<[I1]>.predicate]>,
|
||||||
"bool-like">;
|
"bool-like">;
|
||||||
|
|
||||||
|
// Temporary constraint to allow gradual transition to supporting 0-D vectors.
|
||||||
|
// TODO: Remove this when all ops support 0-D vectors.
|
||||||
|
def BoolLikeOfAnyRank : TypeConstraint<Or<[
|
||||||
|
I1.predicate,
|
||||||
|
VectorOfAnyRankOf<[I1]>.predicate,
|
||||||
|
TensorOf<[I1]>.predicate]>,
|
||||||
|
"bool-like">;
|
||||||
|
|
||||||
// Type constraint for signless-integer-like types: signless integers, indices,
|
// Type constraint for signless-integer-like types: signless integers, indices,
|
||||||
// vectors of signless integers or indices, tensors of signless integers.
|
// vectors of signless integers or indices, tensors of signless integers.
|
||||||
def SignlessIntegerLike : TypeConstraint<Or<[
|
def SignlessIntegerLike : TypeConstraint<Or<[
|
||||||
|
@ -843,6 +854,14 @@ def SignlessIntegerLike : TypeConstraint<Or<[
|
||||||
TensorOf<[AnySignlessIntegerOrIndex]>.predicate]>,
|
TensorOf<[AnySignlessIntegerOrIndex]>.predicate]>,
|
||||||
"signless-integer-like">;
|
"signless-integer-like">;
|
||||||
|
|
||||||
|
// Temporary constraint to allow gradual transition to supporting 0-D vectors.
|
||||||
|
// TODO: Remove this when all ops support 0-D vectors.
|
||||||
|
def SignlessIntegerLikeOfAnyRank : TypeConstraint<Or<[
|
||||||
|
AnySignlessIntegerOrIndex.predicate,
|
||||||
|
VectorOfAnyRankOf<[AnySignlessIntegerOrIndex]>.predicate,
|
||||||
|
TensorOf<[AnySignlessIntegerOrIndex]>.predicate]>,
|
||||||
|
"signless-integer-like">;
|
||||||
|
|
||||||
// Type constraint for float-like types: floats, vectors or tensors thereof.
|
// Type constraint for float-like types: floats, vectors or tensors thereof.
|
||||||
def FloatLike : TypeConstraint<Or<[AnyFloat.predicate,
|
def FloatLike : TypeConstraint<Or<[AnyFloat.predicate,
|
||||||
VectorOf<[AnyFloat]>.predicate, TensorOf<[AnyFloat]>.predicate]>,
|
VectorOf<[AnyFloat]>.predicate, TensorOf<[AnyFloat]>.predicate]>,
|
||||||
|
|
|
@ -352,6 +352,17 @@ func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @cmpi_0dvector(
|
||||||
|
func @cmpi_0dvector(%arg0 : vector<i32>, %arg1 : vector<i32>) {
|
||||||
|
// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
|
||||||
|
// CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast
|
||||||
|
// CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[ARG0]], %[[ARG1]] : vector<1xi32>
|
||||||
|
%0 = arith.cmpi ult, %arg0, %arg1 : vector<i32>
|
||||||
|
std.return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @cmpi_2dvector(
|
// CHECK-LABEL: func @cmpi_2dvector(
|
||||||
func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
|
func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
|
||||||
// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
|
// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
|
||||||
|
|
|
@ -631,6 +631,12 @@ func @test_cmpi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
|
||||||
return %0 : vector<8xi1>
|
return %0 : vector<8xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_cmpi_vector_0d
|
||||||
|
func @test_cmpi_vector_0d(%arg0 : vector<i64>, %arg1 : vector<i64>) -> vector<i1> {
|
||||||
|
%0 = arith.cmpi ult, %arg0, %arg1 : vector<i64>
|
||||||
|
return %0 : vector<i1>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: test_cmpf
|
// CHECK-LABEL: test_cmpf
|
||||||
func @test_cmpf(%arg0 : f64, %arg1 : f64) -> i1 {
|
func @test_cmpf(%arg0 : f64, %arg1 : f64) -> i1 {
|
||||||
%0 = arith.cmpf oeq, %arg0, %arg1 : f64
|
%0 = arith.cmpf oeq, %arg0, %arg1 : f64
|
||||||
|
|
|
@ -67,7 +67,6 @@ func @bitcast_0d() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func @constant_mask_0d() {
|
func @constant_mask_0d() {
|
||||||
%1 = vector.constant_mask [0] : vector<i1>
|
%1 = vector.constant_mask [0] : vector<i1>
|
||||||
// CHECK: ( 0 )
|
// CHECK: ( 0 )
|
||||||
|
@ -78,6 +77,22 @@ func @constant_mask_0d() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @arith_cmpi_0d(%smaller : vector<i32>, %bigger : vector<i32>) {
|
||||||
|
%0 = arith.cmpi ult, %smaller, %bigger : vector<i32>
|
||||||
|
// CHECK: ( 1 )
|
||||||
|
vector.print %0: vector<i1>
|
||||||
|
|
||||||
|
%1 = arith.cmpi ugt, %smaller, %bigger : vector<i32>
|
||||||
|
// CHECK: ( 0 )
|
||||||
|
vector.print %1: vector<i1>
|
||||||
|
|
||||||
|
%2 = arith.cmpi eq, %smaller, %bigger : vector<i32>
|
||||||
|
// CHECK: ( 0 )
|
||||||
|
vector.print %2: vector<i1>
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func @entry() {
|
func @entry() {
|
||||||
%0 = arith.constant 42.0 : f32
|
%0 = arith.constant 42.0 : f32
|
||||||
%1 = arith.constant dense<0.0> : vector<f32>
|
%1 = arith.constant dense<0.0> : vector<f32>
|
||||||
|
@ -96,5 +111,9 @@ func @entry() {
|
||||||
call @bitcast_0d() : () -> ()
|
call @bitcast_0d() : () -> ()
|
||||||
call @constant_mask_0d() : () -> ()
|
call @constant_mask_0d() : () -> ()
|
||||||
|
|
||||||
|
%smaller = arith.constant dense<42> : vector<i32>
|
||||||
|
%bigger = arith.constant dense<4242> : vector<i32>
|
||||||
|
call @arith_cmpi_0d(%smaller, %bigger) : (vector<i32>, vector<i32>) -> ()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue