diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index c75f9fe02312..d34fa9a245d6 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -627,7 +627,37 @@ def Vector_TypeCastOp : }]; } -// TODO(andydavis) Morph this operation into a Vector_MaskOp. +// TODO(andydavis) Add constant folding support. +def Vector_CreateMaskOp : + Vector_Op<"create_mask", [NoSideEffect]>, + Arguments<(ins Variadic:$operands)>, Results<(outs VectorOf<[I1]>)> { + let summary = "creates a vector mask"; + let description = [{ + Creates and returns a vector mask where elements of the result vector + are set to '0' or '1', based on whether the element indices are contained + within a hyper-rectangular region specified by the operands. Specifically, + each operand specifies a range [0, operand-value) for a unique dimension in + the vector result. The conjunction of the operand ranges define + hyper-rectangular region within which elements values are set to 1 + (otherwise element values are set to 0). + + Example: create a vector mask of size 4x3xi1 where elements in range + 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0). + + %1 = vector.create_mask %c3, %c2 : vector<4x3xi1> + + print %1 + columns + 0 1 2 + |------------ + 0 | 1 1 0 + rows 1 | 1 1 0 + 2 | 1 1 0 + 3 | 0 0 0 + }]; +} + +// TODO(andydavis) Delete this op once ContractOp is converted to use VectorMask def Vector_IndexTupleOp : Vector_Op<"make_index_tuple", [NoSideEffect]>, Arguments<(ins Variadic:$operands)>, diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 6086531e3c70..7f3be9d9fa95 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -995,6 +995,37 @@ static LogicalResult verify(TypeCastOp &op) { return success(); } +//===----------------------------------------------------------------------===// +// CreateMaskOp +//===----------------------------------------------------------------------===// + +ParseResult parseCreateMaskOp(OpAsmParser &parser, OperationState &result) { + auto indexType = parser.getBuilder().getIndexType(); + Type resultType; + SmallVector operandInfo; + return failure( + parser.parseOperandList(operandInfo) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(resultType) || + parser.resolveOperands(operandInfo, indexType, result.operands) || + parser.addTypeToList(resultType, result.types)); +} + +static void print(OpAsmPrinter &p, CreateMaskOp &op) { + p << op.getOperationName() << ' '; + p.printOperands(op.operands()); + p << " : " << op.getResult()->getType(); +} + +static LogicalResult verify(CreateMaskOp &op) { + // Verify that an operand was specified for each result vector each dimension. + if (op.getNumOperands() != + op.getResult()->getType().cast().getRank()) + return op.emitOpError( + "must specify an operand for each result vector dimension"); + return success(); +} + //===----------------------------------------------------------------------===// // IndexTupleOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index 0fbcb56f3882..0f19033fb42b 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -606,3 +606,13 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> return } + +// ----- + +func @create_mask() { + %c2 = constant 2 : index + %c3 = constant 3 : index + // expected-error@+1 {{must specify an operand for each result vector dimension}} + %0 = vector.create_mask %c3, %c2 : vector<4x3x7xi1> + return +} diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index 3824dfe20e44..0a52a1ea45b5 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -124,3 +124,15 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>, : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> return } + +// CHECK-LABEL: create_vector_mask +func @create_vector_mask() { + // CHECK: %[[C2:.*]] = constant 2 : index + %c2 = constant 2 : index + // CHECK-NEXT: %[[C3:.*]] = constant 3 : index + %c3 = constant 3 : index + // CHECK-NEXT: vector.create_mask %[[C3]], %[[C2]] : vector<4x3xi1> + %0 = vector.create_mask %c3, %c2 : vector<4x3xi1> + + return +}