[mlir][Vector] Support 0-D vectors in `ConstantMaskOp`

To support creating both a mask with just a single `true` and `false` values,
I had to relax the restriction in the verifier that the rank is always equal to
the length of the attribute array, in other words, we now allow:

- `vector.constant_mask [0] : vector<i1>` which gets lowered to
  `arith.constant dense<false> : vector<i1>`
- `vector.constant_mask [1] : vector<i1>` which gets lowered to
  `arith.constant dense<true> : vector<i1>`

(the attribute list for the 0-D case must be a singleton containing
either `0` or `1`)

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D115023
This commit is contained in:
Michal Terepeta 2021-12-06 07:59:49 +00:00 committed by Nicolas Vasilache
parent 69bcff46bf
commit caf89c0db6
7 changed files with 84 additions and 4 deletions

View File

@ -2111,7 +2111,7 @@ def Vector_TypeCastOp :
def Vector_ConstantMaskOp : def Vector_ConstantMaskOp :
Vector_Op<"constant_mask", [NoSideEffect]>, Vector_Op<"constant_mask", [NoSideEffect]>,
Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>, Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>,
Results<(outs VectorOf<[I1]>)> { Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a constant vector mask"; let summary = "creates a constant vector mask";
let description = [{ let description = [{
Creates and returns a vector mask where elements of the result vector Creates and returns a vector mask where elements of the result vector

View File

@ -3924,8 +3924,19 @@ void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult verify(ConstantMaskOp &op) { static LogicalResult verify(ConstantMaskOp &op) {
// Verify that array attr size matches the rank of the vector result.
auto resultType = op.getResult().getType().cast<VectorType>(); auto resultType = op.getResult().getType().cast<VectorType>();
// Check the corner case of 0-D vectors first.
if (resultType.getRank() == 0) {
if (op.mask_dim_sizes().size() != 1)
return op->emitError("array attr must have length 1 for 0-D vectors");
auto dim = op.mask_dim_sizes()[0].cast<IntegerAttr>().getInt();
if (dim != 0 && dim != 1)
return op->emitError(
"mask dim size must be either 0 or 1 for 0-D vectors");
return success();
}
// Verify that array attr size matches the rank of the vector result.
if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank()) if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
return op.emitOpError( return op.emitOpError(
"must specify array attr of size equal vector result rank"); "must specify array attr of size equal vector result rank");

View File

@ -960,7 +960,20 @@ public:
auto dstType = op.getType(); auto dstType = op.getType();
auto eltType = dstType.getElementType(); auto eltType = dstType.getElementType();
auto dimSizes = op.mask_dim_sizes(); auto dimSizes = op.mask_dim_sizes();
int64_t rank = dimSizes.size(); int64_t rank = dstType.getRank();
if (rank == 0) {
assert(dimSizes.size() == 1 &&
"Expected exactly one dim size for a 0-D vector");
bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
DenseIntElementsAttr::get(
VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
ArrayRef<bool>{value}));
return success();
}
int64_t trueDim = std::min(dstType.getDimSize(0), int64_t trueDim = std::min(dstType.getDimSize(0),
dimSizes[0].cast<IntegerAttr>().getInt()); dimSizes[0].cast<IntegerAttr>().getInt());

View File

@ -1396,6 +1396,26 @@ func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32>
// ----- // -----
func @genbool_0d_f() -> vector<i1> {
%0 = vector.constant_mask [0] : vector<i1>
return %0 : vector<i1>
}
// CHECK-LABEL: func @genbool_0d_f
// CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<i1>
// CHECK: return %[[VAL_0]] : vector<i1>
// -----
func @genbool_0d_t() -> vector<i1> {
%0 = vector.constant_mask [1] : vector<i1>
return %0 : vector<i1>
}
// CHECK-LABEL: func @genbool_0d_t
// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<i1>
// CHECK: return %[[VAL_0]] : vector<i1>
// -----
func @genbool_1d() -> vector<8xi1> { func @genbool_1d() -> vector<8xi1> {
%0 = vector.constant_mask [4] : vector<8xi1> %0 = vector.constant_mask [4] : vector<8xi1>
return %0 : vector<8xi1> return %0 : vector<8xi1>

View File

@ -882,6 +882,20 @@ func @create_mask() {
} }
// -----
func @constant_mask_0d_no_attr() {
// expected-error@+1 {{array attr must have length 1 for 0-D vectors}}
%0 = vector.constant_mask [] : vector<i1>
}
// -----
func @constant_mask_0d_bad_attr() {
// expected-error@+1 {{mask dim size must be either 0 or 1 for 0-D vectors}}
%0 = vector.constant_mask [2] : vector<i1>
}
// ----- // -----
func @constant_mask() { func @constant_mask() {

View File

@ -376,6 +376,15 @@ func @create_vector_mask() {
return return
} }
// CHECK-LABEL: @constant_vector_mask_0d
func @constant_vector_mask_0d() {
// CHECK: vector.constant_mask [0] : vector<i1>
%0 = vector.constant_mask [0] : vector<i1>
// CHECK: vector.constant_mask [1] : vector<i1>
%1 = vector.constant_mask [1] : vector<i1>
return
}
// CHECK-LABEL: @constant_vector_mask // CHECK-LABEL: @constant_vector_mask
func @constant_vector_mask() { func @constant_vector_mask() {
// CHECK: vector.constant_mask [3, 2] : vector<4x3xi1> // CHECK: vector.constant_mask [3, 2] : vector<4x3xi1>

View File

@ -68,6 +68,16 @@ func @bitcast_0d() {
} }
func @constant_mask_0d() {
%1 = vector.constant_mask [0] : vector<i1>
// CHECK: ( 0 )
vector.print %1: vector<i1>
%2 = vector.constant_mask [1] : vector<i1>
// CHECK: ( 1 )
vector.print %2: vector<i1>
return
}
func @entry() { func @entry() {
%0 = arith.constant 42.0 : f32 %0 = arith.constant 42.0 : f32
%1 = arith.constant dense<0.0> : vector<f32> %1 = arith.constant dense<0.0> : vector<f32>
@ -78,10 +88,13 @@ func @entry() {
call @print_vector_0d(%3) : (vector<f32>) -> () call @print_vector_0d(%3) : (vector<f32>) -> ()
%4 = arith.constant 42.0 : f32 %4 = arith.constant 42.0 : f32
// Warning: these must be called in their textual order of definition in the
// file to not mess up FileCheck.
call @splat_0d(%4) : (f32) -> () call @splat_0d(%4) : (f32) -> ()
call @broadcast_0d(%4) : (f32) -> () call @broadcast_0d(%4) : (f32) -> ()
call @bitcast_0d() : () -> () call @bitcast_0d() : () -> ()
call @constant_mask_0d() : () -> ()
return return
} }