forked from OSchip/llvm-project
Add CreateMaskOp to the VectorOps dialect.
PiperOrigin-RevId: 283591888
This commit is contained in:
parent
67515e8d7a
commit
2c13fd9f17
|
@ -627,7 +627,37 @@ def Vector_TypeCastOp :
|
|||
}];
|
||||
}
|
||||
|
||||
// TODO(andydavis) Morph this operation into a Vector_MaskOp.
|
||||
// TODO(andydavis) Add constant folding support.
|
||||
def Vector_CreateMaskOp :
|
||||
Vector_Op<"create_mask", [NoSideEffect]>,
|
||||
Arguments<(ins Variadic<Index>:$operands)>, Results<(outs VectorOf<[I1]>)> {
|
||||
let summary = "creates a vector mask";
|
||||
let description = [{
|
||||
Creates and returns a vector mask where elements of the result vector
|
||||
are set to '0' or '1', based on whether the element indices are contained
|
||||
within a hyper-rectangular region specified by the operands. Specifically,
|
||||
each operand specifies a range [0, operand-value) for a unique dimension in
|
||||
the vector result. The conjunction of the operand ranges define
|
||||
hyper-rectangular region within which elements values are set to 1
|
||||
(otherwise element values are set to 0).
|
||||
|
||||
Example: create a vector mask of size 4x3xi1 where elements in range
|
||||
0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
|
||||
|
||||
%1 = vector.create_mask %c3, %c2 : vector<4x3xi1>
|
||||
|
||||
print %1
|
||||
columns
|
||||
0 1 2
|
||||
|------------
|
||||
0 | 1 1 0
|
||||
rows 1 | 1 1 0
|
||||
2 | 1 1 0
|
||||
3 | 0 0 0
|
||||
}];
|
||||
}
|
||||
|
||||
// TODO(andydavis) Delete this op once ContractOp is converted to use VectorMask
|
||||
def Vector_IndexTupleOp :
|
||||
Vector_Op<"make_index_tuple", [NoSideEffect]>,
|
||||
Arguments<(ins Variadic<Index>:$operands)>,
|
||||
|
|
|
@ -995,6 +995,37 @@ static LogicalResult verify(TypeCastOp &op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CreateMaskOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult parseCreateMaskOp(OpAsmParser &parser, OperationState &result) {
|
||||
auto indexType = parser.getBuilder().getIndexType();
|
||||
Type resultType;
|
||||
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
|
||||
return failure(
|
||||
parser.parseOperandList(operandInfo) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(resultType) ||
|
||||
parser.resolveOperands(operandInfo, indexType, result.operands) ||
|
||||
parser.addTypeToList(resultType, result.types));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, CreateMaskOp &op) {
|
||||
p << op.getOperationName() << ' ';
|
||||
p.printOperands(op.operands());
|
||||
p << " : " << op.getResult()->getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(CreateMaskOp &op) {
|
||||
// Verify that an operand was specified for each result vector each dimension.
|
||||
if (op.getNumOperands() !=
|
||||
op.getResult()->getType().cast<VectorType>().getRank())
|
||||
return op.emitOpError(
|
||||
"must specify an operand for each result vector dimension");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IndexTupleOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -606,3 +606,13 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
|||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @create_mask() {
|
||||
%c2 = constant 2 : index
|
||||
%c3 = constant 3 : index
|
||||
// expected-error@+1 {{must specify an operand for each result vector dimension}}
|
||||
%0 = vector.create_mask %c3, %c2 : vector<4x3x7xi1>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -124,3 +124,15 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
|
|||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: create_vector_mask
|
||||
func @create_vector_mask() {
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
%c2 = constant 2 : index
|
||||
// CHECK-NEXT: %[[C3:.*]] = constant 3 : index
|
||||
%c3 = constant 3 : index
|
||||
// CHECK-NEXT: vector.create_mask %[[C3]], %[[C2]] : vector<4x3xi1>
|
||||
%0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
|
||||
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue