forked from OSchip/llvm-project
[VectorOps] Add a ShuffleOp to the VectorOps dialect
For example, `%0 = vector.shuffle %x, %y [3 : i32, 2 : i32, 1 : i32, 0 : i32] : vector<2xf32>, vector<2xf32>` yields a vector<4xf32> result with a permutation of the elements of %x and %y. PiperOrigin-RevId: 284657191
This commit is contained in:
parent
0e963b9c42
commit
1fe65688d4
|
@ -214,6 +214,59 @@ def Vector_BroadcastOp :
|
|||
}];
|
||||
}
|
||||
|
||||
// vector.shuffle: selects elements (or leading-dimension slices) from the
// concatenation of two input vectors according to a constant i32 mask.
// Both operands are constrained to share the result's element type; rank,
// trailing dimensions, mask length and mask range are checked by the op's
// C++ verifier.
def Vector_ShuffleOp :
  Vector_Op<"shuffle", [NoSideEffect,
     PredOpTrait<"first operand v1 and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 0>>,
     PredOpTrait<"second operand v2 and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 1>>]>,
     Arguments<(ins AnyVector:$v1, AnyVector:$v2, I32ArrayAttr:$mask)>,
     Results<(outs AnyVector:$vector)> {
  let summary = "shuffle operation";
  let description = [{
    The shuffle operation constructs a permutation (or duplication) of elements
    from two input vectors, returning a vector with the same element type as
    the input and a length that is the same as the shuffle mask. The two input
    vectors must have the same element type, rank, and trailing dimension sizes;
    the operation shuffles their values in the leading dimension (which may
    differ in size) according to the given mask. The legality rules are:
    * the two operands must have the same element type as the result
    * the two operands and the result must have the same rank and trailing
      dimension sizes, viz. given two k-D operands
              v1 : <s_1 x s_2 x .. x s_k x type> and
              v2 : <t_1 x t_2 x .. x t_k x type>
      we have s_i = t_i for all 1 < i <= k
    * the mask length equals the leading dimension size of the result
    * numbering the input vector indices left to right across the operands, all
      mask values must be within range, viz. given two k-D operands v1 and v2
      above, all mask values are in the range [0,s_1+t_1)

    Examples:
    ```
    %0 = vector.shuffle %a, %b[0:i32, 3:i32]
                : vector<2xf32>, vector<2xf32>       ; yields vector<2xf32>
    %1 = vector.shuffle %c, %b[0:i32, 1:i32, 2:i32]
                : vector<2x16xf32>, vector<1x16xf32> ; yields vector<3x16xf32>
    %2 = vector.shuffle %a, %b[3:i32, 2:i32, 1:i32, 0:i32]
                : vector<2xf32>, vector<2xf32>       ; yields vector<4xf32>
    ```
  }];
  // Convenience builder taking the mask as plain integers; the last
  // (unnamed) parameter is the shuffle mask.
  let builders = [OpBuilder<"Builder *builder, OperationState &result, Value *v1, Value *v2, ArrayRef<int32_t>">];
  let extraClassDeclaration = [{
    // Name under which the shuffle mask is stored in the attribute dictionary.
    static StringRef getMaskAttrName() { return "mask"; }
    // Typed accessors for the two operand types and the result type.
    VectorType getV1VectorType() {
      return v1()->getType().cast<VectorType>();
    }
    VectorType getV2VectorType() {
      return v2()->getType().cast<VectorType>();
    }
    VectorType getVectorType() {
      return vector()->getType().cast<VectorType>();
    }
  }];
}
|
||||
|
||||
def Vector_ExtractOp :
|
||||
Vector_Op<"extract", [NoSideEffect,
|
||||
PredOpTrait<"operand and result have same element type",
|
||||
|
|
|
@ -458,6 +458,92 @@ static ParseResult parseBroadcastOp(OpAsmParser &parser,
|
|||
parser.addTypeToList(vectorType, result.types));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShuffleOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Builds a `vector.shuffle` of `v1` and `v2` selected by `mask`.
///
/// The result type is derived the same way the custom parser derives it: the
/// element type and trailing dimensions come from `v1`, while the leading
/// dimension is the mask length. Reusing `v1`'s type verbatim would produce an
/// op that fails verification ("mask length mismatch") whenever the mask
/// length differs from `v1`'s leading dimension.
///
/// `v1` must be of vector type. An empty mask is invalid; in that case the
/// leading dimension of `v1` is kept so that the verifier reports the error
/// instead of an invalid vector type being constructed here.
void ShuffleOp::build(Builder *builder, OperationState &result, Value *v1,
                      Value *v2, ArrayRef<int32_t> mask) {
  result.addOperands({v1, v2});

  auto v1Type = v1->getType().cast<VectorType>();
  auto v1Shape = v1Type.getShape();
  SmallVector<int64_t, 4> shape(v1Shape.begin(), v1Shape.end());
  // Only the leading dimension depends on the mask.
  if (!mask.empty())
    shape[0] = static_cast<int64_t>(mask.size());
  result.addTypes(VectorType::get(shape, v1Type.getElementType()));

  result.addAttribute(getMaskAttrName(), builder->getI32ArrayAttr(mask));
}
|
||||
|
||||
/// Prints a shuffle op in its custom assembly form:
///   vector.shuffle <v1>, <v2> <mask> <attr-dict>? : <v1-type>, <v2-type>
/// The mask is printed inline, so it is elided from the trailing attribute
/// dictionary. The result type is not printed.
static void print(OpAsmPrinter &p, ShuffleOp op) {
  auto *lhs = op.v1();
  auto *rhs = op.v2();

  p << op.getOperationName() << " ";
  p << *lhs << ", " << *rhs;
  p << " " << op.mask();
  p.printOptionalAttrDict(op.getAttrs(), {ShuffleOp::getMaskAttrName()});
  p << " : " << lhs->getType();
  p << ", " << rhs->getType();
}
|
||||
|
||||
/// Verifies the structural invariants of a shuffle op, in this order:
///   1. both operands and the result have the same rank;
///   2. they agree on every dimension except the leading one;
///   3. the mask has as many entries as the result's leading dimension;
///   4. every mask entry is an integer in [0, dim0(v1) + dim0(v2)).
/// The first violated rule is reported; mask positions in diagnostics are
/// 1-based.
static LogicalResult verify(ShuffleOp op) {
  VectorType outType = op.getVectorType();
  VectorType lhsType = op.getV1VectorType();
  VectorType rhsType = op.getV2VectorType();

  // Rule 1: ranks.
  int64_t rank = lhsType.getRank();
  if (outType.getRank() != rank || rank != rhsType.getRank())
    return op.emitOpError("rank mismatch");

  // Rule 2: trailing dimension sizes (the leading dimension is skipped).
  for (int64_t dim = 1; dim < rank; ++dim) {
    int64_t lhsSize = lhsType.getDimSize(dim);
    if (outType.getDimSize(dim) != lhsSize ||
        lhsSize != rhsType.getDimSize(dim))
      return op.emitOpError("dimension mismatch");
  }

  // Rule 3: mask length.
  auto maskEntries = op.mask().getValue();
  int64_t numEntries = maskEntries.size();
  if (numEntries != outType.getDimSize(0))
    return op.emitOpError("mask length mismatch");

  // Rule 4: each entry indexes into the operands numbered back to back.
  int64_t numSelectable = lhsType.getDimSize(0) + rhsType.getDimSize(0);
  size_t position = 0;
  for (auto entry : maskEntries) {
    ++position;
    auto index = entry.dyn_cast<IntegerAttr>();
    if (!index || index.getInt() < 0 || index.getInt() >= numSelectable)
      return op.emitOpError("mask index #") << position << " out of range";
  }
  return success();
}
|
||||
|
||||
/// Parses the custom assembly form of a shuffle op:
///   vector.shuffle <v1>, <v2> <mask> <attr-dict>? : <v1-type>, <v2-type>
/// The result type is not spelled in the assembly; it is inferred as a vector
/// whose leading dimension is the mask length and whose element type and
/// trailing dimensions are taken from the first operand's type.
static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType lhs, rhs;
  Attribute rawMask;
  VectorType lhsType, rhsType;

  // Operands: `<v1>, <v2>`.
  if (parser.parseOperand(lhs) || parser.parseComma() ||
      parser.parseOperand(rhs))
    return failure();

  // Mask (recorded under the mask attribute name), then any extra attributes.
  if (parser.parseAttribute(rawMask, ShuffleOp::getMaskAttrName(),
                            result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Operand types: `: <v1-type>, <v2-type>`.
  if (parser.parseColonType(lhsType) || parser.parseComma() ||
      parser.parseType(rhsType))
    return failure();

  if (parser.resolveOperand(lhs, lhsType, result.operands) ||
      parser.resolveOperand(rhs, rhsType, result.operands))
    return failure();

  // The mask must be a non-empty array attribute.
  auto mask = rawMask.dyn_cast<ArrayAttr>();
  if (!mask)
    return parser.emitError(parser.getNameLoc(), "missing mask attribute");
  int64_t numSelected = mask.size();
  if (numSelected <= 0)
    return parser.emitError(parser.getNameLoc(), "invalid mask length");

  // Result shape: [mask length, trailing dimensions of the first operand...].
  int64_t rank = lhsType.getRank();
  SmallVector<int64_t, 4> resultShape;
  resultShape.reserve(rank);
  resultShape.push_back(numSelected);
  for (int64_t dim = 1; dim < rank; ++dim)
    resultShape.push_back(lhsType.getDimSize(dim));

  parser.addTypeToList(VectorType::get(resultShape, lhsType.getElementType()),
                       result.types);
  return success();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -235,18 +235,18 @@ func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32>
|
|||
return %0 : vector<3x16xf32>
|
||||
}
|
||||
// CHECK-LABEL: extract_vec_2d_from_vec_3d
|
||||
// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
|
||||
// CHECK: llvm.return %{{.*}} : !llvm<"[3 x <16 x float>]">
|
||||
// CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
|
||||
// CHECK: llvm.return {{.*}} : !llvm<"[3 x <16 x float>]">
|
||||
|
||||
func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
|
||||
%0 = vector.extract %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32>
|
||||
return %0 : f32
|
||||
}
|
||||
// CHECK-LABEL: extract_element_from_vec_3d
|
||||
// CHECK: llvm.extractvalue %{{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
|
||||
// CHECK: llvm.extractvalue {{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
|
||||
// CHECK: llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||
// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
|
||||
// CHECK: llvm.return %{{.*}} : !llvm.float
|
||||
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
|
||||
// CHECK: llvm.return {{.*}} : !llvm.float
|
||||
|
||||
func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
|
||||
%0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref<vector<8x8x8xf32>>
|
||||
|
|
|
@ -31,6 +31,41 @@ func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// Negative test: the result takes its element type from the first operand (f32), so an i32 second operand must be rejected.
func @shuffle_elt_type_mismatch(%arg0: vector<2xf32>, %arg1: vector<2xi32>) {
|
||||
// expected-error@+1 {{'vector.shuffle' op failed to verify that second operand v2 and result have same element type}}
|
||||
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: a rank-1 operand cannot be shuffled with a rank-2 operand.
func @shuffle_rank_mismatch(%arg0: vector<2xf32>, %arg1: vector<4x2xf32>) {
|
||||
// expected-error@+1 {{'vector.shuffle' op rank mismatch}}
|
||||
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<4x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: the operands agree on rank but differ in the trailing dimension (2 vs 4).
func @shuffle_trailing_dim_size_mismatch(%arg0: vector<2x2xf32>, %arg1: vector<2x4xf32>) {
|
||||
// expected-error@+1 {{'vector.shuffle' op dimension mismatch}}
|
||||
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2x2xf32>, vector<2x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: two 2-element operands allow mask values 0..3, so the second mask entry (4) is out of range; the diagnostic numbers mask entries from 1.
func @shuffle_index_out_of_range(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
|
||||
// expected-error@+1 {{'vector.shuffle' op mask index #2 out of range}}
|
||||
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 4 : i32] : vector<2xf32>, vector<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: an empty mask is rejected by the op's custom parser, hence the "custom op" prefix in the diagnostic.
func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
|
||||
// expected-error@+1 {{custom op 'vector.shuffle' invalid mask length}}
|
||||
%1 = vector.shuffle %arg0, %arg1 [] : vector<2xf32>, vector<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extract_vector_type(%arg0: index) {
|
||||
// expected-error@+1 {{expected vector type}}
|
||||
%1 = vector.extract %arg0[] : index
|
||||
|
|
|
@ -24,20 +24,38 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>) {
|
|||
|
||||
// CHECK-LABEL: @vector_broadcast
|
||||
func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> {
|
||||
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
|
||||
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
|
||||
%0 = vector.broadcast %a : f32 to vector<16xf32>
|
||||
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
|
||||
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
|
||||
%1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32>
|
||||
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32>
|
||||
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32>
|
||||
%2 = vector.broadcast %c : vector<1x16xf32> to vector<8x16xf32>
|
||||
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
|
||||
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
|
||||
%3 = vector.broadcast %d : vector<8x1xf32> to vector<8x16xf32>
|
||||
return %3 : vector<8x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @shuffle1D
|
||||
// Round-trips a chain of 1-D shuffles in which the result length always equals the mask length: 2+2 -> 4 elements, 4+4 -> 3 elements, 3+4 -> 2 elements (mask value 6 is the last valid index of the 3+4 inputs).
func @shuffle1D(%a: vector<2xf32>, %b: vector<4xf32>) -> vector<2xf32> {
|
||||
// CHECK: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 1 : i32, 2 : i32, 3 : i32] : vector<2xf32>, vector<2xf32>
|
||||
%1 = vector.shuffle %a, %a[0 : i32, 1 : i32, 2: i32, 3 : i32] : vector<2xf32>, vector<2xf32>
|
||||
// CHECK-NEXT: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 1 : i32, 2 : i32] : vector<4xf32>, vector<4xf32>
|
||||
%2 = vector.shuffle %1, %b[0 : i32, 1 : i32, 2 : i32] : vector<4xf32>, vector<4xf32>
|
||||
// CHECK-NEXT: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 6 : i32] : vector<3xf32>, vector<4xf32>
|
||||
%3 = vector.shuffle %2, %b[0 : i32, 6 : i32] : vector<3xf32>, vector<4xf32>
|
||||
return %3 : vector<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @shuffle2D
|
||||
// Round-trips a 2-D shuffle: the 3-entry mask selects along the leading dimension of the 1x4 and 2x4 operands, and the trailing dimension (4) is carried into the 3x4 result.
func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
|
||||
// CHECK: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 1 : i32, 2 : i32] : vector<1x4xf32>, vector<2x4xf32>
|
||||
%1 = vector.shuffle %a, %b[0 : i32, 1 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
|
||||
return %1 : vector<3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @extract
|
||||
func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
|
||||
// CHECK: vector.extract {{.*}}[3 : i32] : vector<4x8x16xf32>
|
||||
// CHECK: vector.extract {{.*}}[3 : i32] : vector<4x8x16xf32>
|
||||
%1 = vector.extract %arg0[3 : i32] : vector<4x8x16xf32>
|
||||
// CHECK-NEXT: vector.extract {{.*}}[3 : i32, 3 : i32] : vector<4x8x16xf32>
|
||||
%2 = vector.extract %arg0[3 : i32, 3 : i32] : vector<4x8x16xf32>
|
||||
|
@ -47,35 +65,35 @@ func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f
|
|||
}
|
||||
|
||||
// CHECK-LABEL: @insert
|
||||
func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) {
|
||||
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
|
||||
func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
|
||||
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
|
||||
%1 = vector.insert %c, %res[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
|
||||
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
|
||||
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
|
||||
%2 = vector.insert %b, %res[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
|
||||
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
|
||||
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
|
||||
%3 = vector.insert %a, %res[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
|
||||
return
|
||||
return %3 : vector<4x8x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @outerproduct
|
||||
func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
|
||||
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
|
||||
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
|
||||
%0 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
|
||||
// CHECK: vector.outerproduct {{.*}}, {{.*}}, {{.*}} : vector<4xf32>, vector<8xf32>
|
||||
// CHECK: vector.outerproduct {{.*}}, {{.*}}, {{.*}} : vector<4xf32>, vector<8xf32>
|
||||
%1 = vector.outerproduct %arg0, %arg1, %arg2 : vector<4xf32>, vector<8xf32>
|
||||
return %1 : vector<4x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @insert_strided_slice
|
||||
func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
|
||||
// CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
|
||||
// CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
|
||||
%1 = vector.insert_strided_slice %a, %b {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @strided_slice
|
||||
func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
|
||||
// CHECK: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32>
|
||||
// CHECK: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32>
|
||||
%1 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
|
||||
return %1: vector<2x2x16xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue