[mlir][Vector] Support 0-D vectors in `CmpIOp`

Following the example of `VectorOfAnyRankOf`, I've done a few changes in the
`.td` files to help with adding the support for the 0-D case gradually.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D115220
This commit is contained in:
Michal Terepeta 2021-12-12 11:44:04 +00:00 committed by Nicolas Vasilache
parent 8e833d081b
commit a0c930d312
5 changed files with 67 additions and 5 deletions

View File

@ -116,6 +116,13 @@ class Arith_CompareOp<string mnemonic, list<OpTrait> traits = []> :
let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
}
// Just like `Arith_CompareOp` but also admits 0-D vectors. Introduced
// temporarily to allow gradual transition to 0-D vectors.
class Arith_CompareOpOfAnyRank<string mnemonic, list<OpTrait> traits = []> :
Arith_CompareOp<mnemonic, traits> {
let results = (outs BoolLikeOfAnyRank:$result);
}
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
@ -990,7 +997,7 @@ def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
// CmpIOp
//===----------------------------------------------------------------------===//
def Arith_CmpIOp : Arith_CompareOp<"cmpi"> {
def Arith_CmpIOp : Arith_CompareOpOfAnyRank<"cmpi"> {
let summary = "integer comparison operation";
let description = [{
The `cmpi` operation is a generic comparison for integer-like types. Its two
@ -1057,8 +1064,8 @@ def Arith_CmpIOp : Arith_CompareOp<"cmpi"> {
}];
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
SignlessIntegerLike:$lhs,
SignlessIntegerLike:$rhs);
SignlessIntegerLikeOfAnyRank:$lhs,
SignlessIntegerLikeOfAnyRank:$rhs);
let builders = [
OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{

View File

@ -213,6 +213,7 @@ def IsVectorTypePred : And<[CPred<"$_self.isa<::mlir::VectorType>()">,
CPred<"$_self.cast<::mlir::VectorType>().getRank() > 0">]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
def IsVectorOfAnyRankTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
// Whether a type is a TensorType.
@ -603,7 +604,9 @@ class HasAnyRankOfPred<list<int> ranks> : And<[
class VectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
"::mlir::VectorType">;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
class VectorOfAnyRankOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
"::mlir::VectorType">;
@ -835,6 +838,14 @@ def BoolLike : TypeConstraint<Or<[I1.predicate, VectorOf<[I1]>.predicate,
TensorOf<[I1]>.predicate]>,
"bool-like">;
// Temporary constraint to allow gradual transition to supporting 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
def BoolLikeOfAnyRank : TypeConstraint<Or<[
I1.predicate,
VectorOfAnyRankOf<[I1]>.predicate,
TensorOf<[I1]>.predicate]>,
"bool-like">;
// Type constraint for signless-integer-like types: signless integers, indices,
// vectors of signless integers or indices, tensors of signless integers.
def SignlessIntegerLike : TypeConstraint<Or<[
@ -843,6 +854,14 @@ def SignlessIntegerLike : TypeConstraint<Or<[
TensorOf<[AnySignlessIntegerOrIndex]>.predicate]>,
"signless-integer-like">;
// Temporary constraint to allow gradual transition to supporting 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
def SignlessIntegerLikeOfAnyRank : TypeConstraint<Or<[
AnySignlessIntegerOrIndex.predicate,
VectorOfAnyRankOf<[AnySignlessIntegerOrIndex]>.predicate,
TensorOf<[AnySignlessIntegerOrIndex]>.predicate]>,
"signless-integer-like">;
// Type constraint for float-like types: floats, vectors or tensors thereof.
def FloatLike : TypeConstraint<Or<[AnyFloat.predicate,
VectorOf<[AnyFloat]>.predicate, TensorOf<[AnyFloat]>.predicate]>,

View File

@ -352,6 +352,17 @@ func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
// -----
// CHECK-LABEL: func @cmpi_0dvector(
func @cmpi_0dvector(%arg0 : vector<i32>, %arg1 : vector<i32>) {
// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
// CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast
// CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[ARG0]], %[[ARG1]] : vector<1xi32>
%0 = arith.cmpi ult, %arg0, %arg1 : vector<i32>
std.return
}
// -----
// CHECK-LABEL: func @cmpi_2dvector(
func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast

View File

@ -631,6 +631,12 @@ func @test_cmpi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
return %0 : vector<8xi1>
}
// CHECK-LABEL: test_cmpi_vector_0d
func @test_cmpi_vector_0d(%arg0 : vector<i64>, %arg1 : vector<i64>) -> vector<i1> {
%0 = arith.cmpi ult, %arg0, %arg1 : vector<i64>
return %0 : vector<i1>
}
// CHECK-LABEL: test_cmpf
func @test_cmpf(%arg0 : f64, %arg1 : f64) -> i1 {
%0 = arith.cmpf oeq, %arg0, %arg1 : f64

View File

@ -67,7 +67,6 @@ func @bitcast_0d() {
return
}
func @constant_mask_0d() {
%1 = vector.constant_mask [0] : vector<i1>
// CHECK: ( 0 )
@ -78,6 +77,22 @@ func @constant_mask_0d() {
return
}
func @arith_cmpi_0d(%smaller : vector<i32>, %bigger : vector<i32>) {
%0 = arith.cmpi ult, %smaller, %bigger : vector<i32>
// CHECK: ( 1 )
vector.print %0: vector<i1>
%1 = arith.cmpi ugt, %smaller, %bigger : vector<i32>
// CHECK: ( 0 )
vector.print %1: vector<i1>
%2 = arith.cmpi eq, %smaller, %bigger : vector<i32>
// CHECK: ( 0 )
vector.print %2: vector<i1>
return
}
func @entry() {
%0 = arith.constant 42.0 : f32
%1 = arith.constant dense<0.0> : vector<f32>
@ -96,5 +111,9 @@ func @entry() {
call @bitcast_0d() : () -> ()
call @constant_mask_0d() : () -> ()
%smaller = arith.constant dense<42> : vector<i32>
%bigger = arith.constant dense<4242> : vector<i32>
call @arith_cmpi_0d(%smaller, %bigger) : (vector<i32>, vector<i32>) -> ()
return
}