[mlir][Vector] Support 0-D vectors in `ConstantMaskOp`

To support creating both a mask with just a single `true` and `false` values,
I had to relax the restriction in the verifier that the rank is always equal to
the length of the attribute array, in other words, we now allow:

- `vector.constant_mask [0] : vector<i1>` which gets lowered to
  `arith.constant dense<false> : vector<i1>`
- `vector.constant_mask [1] : vector<i1>` which gets lowered to
  `arith.constant dense<true> : vector<i1>`

(the attribute list for the 0-D case must be a singleton containing
either `0` or `1`)

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D115023
This commit is contained in:
Michal Terepeta 2021-12-06 07:59:49 +00:00 committed by Nicolas Vasilache
parent 69bcff46bf
commit caf89c0db6
7 changed files with 84 additions and 4 deletions

View File

@ -2111,7 +2111,7 @@ def Vector_TypeCastOp :
def Vector_ConstantMaskOp :
Vector_Op<"constant_mask", [NoSideEffect]>,
Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>,
Results<(outs VectorOf<[I1]>)> {
Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a constant vector mask";
let description = [{
Creates and returns a vector mask where elements of the result vector

View File

@ -3924,8 +3924,19 @@ void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
//===----------------------------------------------------------------------===//
static LogicalResult verify(ConstantMaskOp &op) {
// Verify that array attr size matches the rank of the vector result.
auto resultType = op.getResult().getType().cast<VectorType>();
// Check the corner case of 0-D vectors first.
if (resultType.getRank() == 0) {
if (op.mask_dim_sizes().size() != 1)
return op->emitError("array attr must have length 1 for 0-D vectors");
auto dim = op.mask_dim_sizes()[0].cast<IntegerAttr>().getInt();
if (dim != 0 && dim != 1)
return op->emitError(
"mask dim size must be either 0 or 1 for 0-D vectors");
return success();
}
// Verify that array attr size matches the rank of the vector result.
if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
return op.emitOpError(
"must specify array attr of size equal vector result rank");

View File

@ -960,7 +960,20 @@ public:
auto dstType = op.getType();
auto eltType = dstType.getElementType();
auto dimSizes = op.mask_dim_sizes();
int64_t rank = dimSizes.size();
int64_t rank = dstType.getRank();
if (rank == 0) {
assert(dimSizes.size() == 1 &&
"Expected exactly one dim size for a 0-D vector");
bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
DenseIntElementsAttr::get(
VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
ArrayRef<bool>{value}));
return success();
}
int64_t trueDim = std::min(dstType.getDimSize(0),
dimSizes[0].cast<IntegerAttr>().getInt());

View File

@ -1396,6 +1396,26 @@ func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32>
// -----
func @genbool_0d_f() -> vector<i1> {
%0 = vector.constant_mask [0] : vector<i1>
return %0 : vector<i1>
}
// CHECK-LABEL: func @genbool_0d_f
// CHECK: %[[VAL_0:.*]] = arith.constant dense<false> : vector<i1>
// CHECK: return %[[VAL_0]] : vector<i1>
// -----
func @genbool_0d_t() -> vector<i1> {
%0 = vector.constant_mask [1] : vector<i1>
return %0 : vector<i1>
}
// CHECK-LABEL: func @genbool_0d_t
// CHECK: %[[VAL_0:.*]] = arith.constant dense<true> : vector<i1>
// CHECK: return %[[VAL_0]] : vector<i1>
// -----
func @genbool_1d() -> vector<8xi1> {
%0 = vector.constant_mask [4] : vector<8xi1>
return %0 : vector<8xi1>

View File

@ -882,6 +882,20 @@ func @create_mask() {
}
// -----
func @constant_mask_0d_no_attr() {
// expected-error@+1 {{array attr must have length 1 for 0-D vectors}}
%0 = vector.constant_mask [] : vector<i1>
}
// -----
func @constant_mask_0d_bad_attr() {
// expected-error@+1 {{mask dim size must be either 0 or 1 for 0-D vectors}}
%0 = vector.constant_mask [2] : vector<i1>
}
// -----
func @constant_mask() {

View File

@ -376,6 +376,15 @@ func @create_vector_mask() {
return
}
// CHECK-LABEL: @constant_vector_mask_0d
func @constant_vector_mask_0d() {
// CHECK: vector.constant_mask [0] : vector<i1>
%0 = vector.constant_mask [0] : vector<i1>
// CHECK: vector.constant_mask [1] : vector<i1>
%1 = vector.constant_mask [1] : vector<i1>
return
}
// CHECK-LABEL: @constant_vector_mask
func @constant_vector_mask() {
// CHECK: vector.constant_mask [3, 2] : vector<4x3xi1>

View File

@ -68,6 +68,16 @@ func @bitcast_0d() {
}
func @constant_mask_0d() {
%1 = vector.constant_mask [0] : vector<i1>
// CHECK: ( 0 )
vector.print %1: vector<i1>
%2 = vector.constant_mask [1] : vector<i1>
// CHECK: ( 1 )
vector.print %2: vector<i1>
return
}
func @entry() {
%0 = arith.constant 42.0 : f32
%1 = arith.constant dense<0.0> : vector<f32>
@ -78,10 +88,13 @@ func @entry() {
call @print_vector_0d(%3) : (vector<f32>) -> ()
%4 = arith.constant 42.0 : f32
// Warning: these must be called in their textual order of definition in the
// file to not mess up FileCheck.
call @splat_0d(%4) : (f32) -> ()
call @broadcast_0d(%4) : (f32) -> ()
call @bitcast_0d() : () -> ()
call @constant_mask_0d() : () -> ()
return
}