Add CreateMaskOp to the VectorOps dialect.

PiperOrigin-RevId: 283591888
This commit is contained in:
Andy Davis 2019-12-03 11:55:09 -08:00 committed by A. Unique TensorFlower
parent 67515e8d7a
commit 2c13fd9f17
4 changed files with 84 additions and 1 deletions

View File

@ -627,7 +627,37 @@ def Vector_TypeCastOp :
}]; }];
} }
// TODO(andydavis) Morph this operation into a Vector_MaskOp. // TODO(andydavis) Add constant folding support.
def Vector_CreateMaskOp :
Vector_Op<"create_mask", [NoSideEffect]>,
Arguments<(ins Variadic<Index>:$operands)>, Results<(outs VectorOf<[I1]>)> {
let summary = "creates a vector mask";
let description = [{
Creates and returns a vector mask where elements of the result vector
are set to '0' or '1', based on whether the element indices are contained
within a hyper-rectangular region specified by the operands. Specifically,
each operand specifies a range [0, operand-value) for a unique dimension in
the vector result. The conjunction of the operand ranges define
hyper-rectangular region within which elements values are set to 1
(otherwise element values are set to 0).
Example: create a vector mask of size 4x3xi1 where elements in range
0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
%1 = vector.create_mask %c3, %c2 : vector<4x3xi1>
print %1
columns
0 1 2
|------------
0 | 1 1 0
rows 1 | 1 1 0
2 | 1 1 0
3 | 0 0 0
}];
}
// TODO(andydavis) Delete this op once ContractOp is converted to use VectorMask
def Vector_IndexTupleOp : def Vector_IndexTupleOp :
Vector_Op<"make_index_tuple", [NoSideEffect]>, Vector_Op<"make_index_tuple", [NoSideEffect]>,
Arguments<(ins Variadic<Index>:$operands)>, Arguments<(ins Variadic<Index>:$operands)>,

View File

@ -995,6 +995,37 @@ static LogicalResult verify(TypeCastOp &op) {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//
ParseResult parseCreateMaskOp(OpAsmParser &parser, OperationState &result) {
auto indexType = parser.getBuilder().getIndexType();
Type resultType;
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
return failure(
parser.parseOperandList(operandInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(resultType) ||
parser.resolveOperands(operandInfo, indexType, result.operands) ||
parser.addTypeToList(resultType, result.types));
}
static void print(OpAsmPrinter &p, CreateMaskOp &op) {
p << op.getOperationName() << ' ';
p.printOperands(op.operands());
p << " : " << op.getResult()->getType();
}
static LogicalResult verify(CreateMaskOp &op) {
// Verify that an operand was specified for each result vector each dimension.
if (op.getNumOperands() !=
op.getResult()->getType().cast<VectorType>().getRank())
return op.emitOpError(
"must specify an operand for each result vector dimension");
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// IndexTupleOp // IndexTupleOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -606,3 +606,13 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return return
} }
// -----
func @create_mask() {
%c2 = constant 2 : index
%c3 = constant 3 : index
// expected-error@+1 {{must specify an operand for each result vector dimension}}
%0 = vector.create_mask %c3, %c2 : vector<4x3x7xi1>
return
}

View File

@ -124,3 +124,15 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
return return
} }
// CHECK-LABEL: create_vector_mask
func @create_vector_mask() {
// CHECK: %[[C2:.*]] = constant 2 : index
%c2 = constant 2 : index
// CHECK-NEXT: %[[C3:.*]] = constant 3 : index
%c3 = constant 3 : index
// CHECK-NEXT: vector.create_mask %[[C3]], %[[C2]] : vector<4x3xi1>
%0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
return
}