forked from OSchip/llvm-project
Add CreateMaskOp to the VectorOps dialect.
PiperOrigin-RevId: 283591888
This commit is contained in:
parent
67515e8d7a
commit
2c13fd9f17
|
@ -627,7 +627,37 @@ def Vector_TypeCastOp :
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(andydavis) Morph this operation into a Vector_MaskOp.
|
// TODO(andydavis) Add constant folding support.
|
||||||
|
def Vector_CreateMaskOp :
|
||||||
|
Vector_Op<"create_mask", [NoSideEffect]>,
|
||||||
|
Arguments<(ins Variadic<Index>:$operands)>, Results<(outs VectorOf<[I1]>)> {
|
||||||
|
let summary = "creates a vector mask";
|
||||||
|
let description = [{
|
||||||
|
Creates and returns a vector mask where elements of the result vector
|
||||||
|
are set to '0' or '1', based on whether the element indices are contained
|
||||||
|
within a hyper-rectangular region specified by the operands. Specifically,
|
||||||
|
each operand specifies a range [0, operand-value) for a unique dimension in
|
||||||
|
the vector result. The conjunction of the operand ranges define
|
||||||
|
hyper-rectangular region within which elements values are set to 1
|
||||||
|
(otherwise element values are set to 0).
|
||||||
|
|
||||||
|
Example: create a vector mask of size 4x3xi1 where elements in range
|
||||||
|
0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
|
||||||
|
|
||||||
|
%1 = vector.create_mask %c3, %c2 : vector<4x3xi1>
|
||||||
|
|
||||||
|
print %1
|
||||||
|
columns
|
||||||
|
0 1 2
|
||||||
|
|------------
|
||||||
|
0 | 1 1 0
|
||||||
|
rows 1 | 1 1 0
|
||||||
|
2 | 1 1 0
|
||||||
|
3 | 0 0 0
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(andydavis) Delete this op once ContractOp is converted to use VectorMask
|
||||||
def Vector_IndexTupleOp :
|
def Vector_IndexTupleOp :
|
||||||
Vector_Op<"make_index_tuple", [NoSideEffect]>,
|
Vector_Op<"make_index_tuple", [NoSideEffect]>,
|
||||||
Arguments<(ins Variadic<Index>:$operands)>,
|
Arguments<(ins Variadic<Index>:$operands)>,
|
||||||
|
|
|
@ -995,6 +995,37 @@ static LogicalResult verify(TypeCastOp &op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// CreateMaskOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
ParseResult parseCreateMaskOp(OpAsmParser &parser, OperationState &result) {
|
||||||
|
auto indexType = parser.getBuilder().getIndexType();
|
||||||
|
Type resultType;
|
||||||
|
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
|
||||||
|
return failure(
|
||||||
|
parser.parseOperandList(operandInfo) ||
|
||||||
|
parser.parseOptionalAttrDict(result.attributes) ||
|
||||||
|
parser.parseColonType(resultType) ||
|
||||||
|
parser.resolveOperands(operandInfo, indexType, result.operands) ||
|
||||||
|
parser.addTypeToList(resultType, result.types));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void print(OpAsmPrinter &p, CreateMaskOp &op) {
|
||||||
|
p << op.getOperationName() << ' ';
|
||||||
|
p.printOperands(op.operands());
|
||||||
|
p << " : " << op.getResult()->getType();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verify(CreateMaskOp &op) {
|
||||||
|
// Verify that an operand was specified for each result vector each dimension.
|
||||||
|
if (op.getNumOperands() !=
|
||||||
|
op.getResult()->getType().cast<VectorType>().getRank())
|
||||||
|
return op.emitOpError(
|
||||||
|
"must specify an operand for each result vector dimension");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// IndexTupleOp
|
// IndexTupleOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -606,3 +606,13 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @create_mask() {
|
||||||
|
%c2 = constant 2 : index
|
||||||
|
%c3 = constant 3 : index
|
||||||
|
// expected-error@+1 {{must specify an operand for each result vector dimension}}
|
||||||
|
%0 = vector.create_mask %c3, %c2 : vector<4x3x7xi1>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -124,3 +124,15 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
|
||||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: create_vector_mask
|
||||||
|
func @create_vector_mask() {
|
||||||
|
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||||
|
%c2 = constant 2 : index
|
||||||
|
// CHECK-NEXT: %[[C3:.*]] = constant 3 : index
|
||||||
|
%c3 = constant 3 : index
|
||||||
|
// CHECK-NEXT: vector.create_mask %[[C3]], %[[C2]] : vector<4x3xi1>
|
||||||
|
%0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue