[VectorOps] Add a ShuffleOp to the VectorOps dialect

For example

 %0 = vector.shuffle %x, %y [3 : i32, 2 : i32, 1 : i32, 0 : i32] : vector<2xf32>, vector<2xf32>

yields a vector<4xf32> result with a permutation of the elements of %x and %y

PiperOrigin-RevId: 284657191
This commit is contained in:
Aart Bik 2019-12-09 16:15:02 -08:00 committed by A. Unique TensorFlower
parent 0e963b9c42
commit 1fe65688d4
5 changed files with 211 additions and 19 deletions

View File

@ -214,6 +214,59 @@ def Vector_BroadcastOp :
}];
}
// Definition of `vector.shuffle`: selects slices along the leading dimension
// of two input vectors, in the order given by the `mask` attribute. The two
// predicate traits tie the element type of each operand to that of the
// result; rank, trailing-dimension, mask-length and mask-range constraints
// are listed in the description below.
def Vector_ShuffleOp :
  Vector_Op<"shuffle", [NoSideEffect,
     PredOpTrait<"first operand v1 and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 0>>,
     PredOpTrait<"second operand v2 and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 1>>]>,
     Arguments<(ins AnyVector:$v1, AnyVector:$v2, I32ArrayAttr:$mask)>,
     Results<(outs AnyVector:$vector)> {
  let summary = "shuffle operation";
  let description = [{
    The shuffle operation constructs a permutation (or duplication) of elements
    from two input vectors, returning a vector with the same element type as
    the input and a length that is the same as the shuffle mask. The two input
    vectors must have the same element type, rank, and trailing dimension
    sizes; their values are shuffled in the leading dimension (which may differ
    in size) according to the given mask. The legality rules are:
    * the two operands must have the same element type as the result
    * the two operands and the result must have the same rank and trailing
      dimension sizes, viz. given two k-D operands
              v1 : <s_1 x s_2 x .. x s_k x type> and
              v2 : <t_1 x t_2 x .. x t_k x type>
      we have s_i = t_i for all 1 < i <= k
    * the mask length equals the leading dimension size of the result
    * numbering the input vector indices left to right across the operands, all
      mask values must be within range, viz. given two k-D operands v1 and v2
      above, all mask values are in the range [0,s_1+t_1)

    Examples:
    ```
      %0 = vector.shuffle %a, %b[0:i32, 3:i32]
                 : vector<2xf32>, vector<2xf32>       ; yields vector<2xf32>
      %1 = vector.shuffle %c, %d[0:i32, 1:i32, 2:i32]
                 : vector<2x16xf32>, vector<1x16xf32> ; yields vector<3x16xf32>
      %2 = vector.shuffle %a, %b[3:i32, 2:i32, 1:i32, 0:i32]
                 : vector<2xf32>, vector<2xf32>       ; yields vector<4xf32>
    ```
  }];

  // Convenience builder taking the mask as plain 32-bit integers.
  let builders = [OpBuilder<"Builder *builder, OperationState &result, Value *v1, Value *v2, ArrayRef<int32_t>">];

  let extraClassDeclaration = [{
    // Name under which the shuffle mask is stored in the attribute dictionary.
    static StringRef getMaskAttrName() { return "mask"; }
    // Vector type of the first operand.
    VectorType getV1VectorType() {
      return v1()->getType().cast<VectorType>();
    }
    // Vector type of the second operand.
    VectorType getV2VectorType() {
      return v2()->getType().cast<VectorType>();
    }
    // Vector type of the result.
    VectorType getVectorType() {
      return vector()->getType().cast<VectorType>();
    }
  }];
}
def Vector_ExtractOp :
Vector_Op<"extract", [NoSideEffect,
PredOpTrait<"operand and result have same element type",

View File

@ -458,6 +458,92 @@ static ParseResult parseBroadcastOp(OpAsmParser &parser,
parser.addTypeToList(vectorType, result.types));
}
//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//
/// Builds a ShuffleOp from two vector operands and a mask of plain integers.
///
/// The result type is derived the same way the custom parser derives it: the
/// leading dimension equals the mask length, and the trailing dimensions and
/// element type are taken from `v1`. Simply reusing the type of `v1` would
/// produce an op that fails verification whenever the mask length differs from
/// the leading dimension of `v1`.
///
/// `v1` must have a vector type (the `cast` below asserts otherwise).
/// NOTE(review): an empty mask yields a zero-sized leading dimension; the
/// parser rejects that case, so callers are presumably expected to pass a
/// non-empty mask.
void ShuffleOp::build(Builder *builder, OperationState &result, Value *v1,
                      Value *v2, ArrayRef<int32_t> mask) {
  result.addOperands({v1, v2});

  // Result shape: <mask.size() x trailing dimensions of v1>.
  auto v1Type = v1->getType().cast<VectorType>();
  int64_t v1Rank = v1Type.getRank();
  SmallVector<int64_t, 4> shape;
  shape.reserve(v1Rank);
  shape.push_back(static_cast<int64_t>(mask.size()));
  for (int64_t r = 1; r < v1Rank; ++r)
    shape.push_back(v1Type.getDimSize(r));
  result.addTypes(VectorType::get(shape, v1Type.getElementType()));

  result.addAttribute(getMaskAttrName(), builder->getI32ArrayAttr(mask));
}
/// Prints a ShuffleOp in its custom assembly form:
///   vector.shuffle %v1, %v2 [mask] {optional-attrs} : v1-type, v2-type
/// The mask is printed inline, so it is elided from the trailing attribute
/// dictionary. The result type is not printed.
static void print(OpAsmPrinter &p, ShuffleOp op) {
  // Operation name followed by the two operands.
  p << op.getOperationName() << " ";
  p << *op.v1() << ", " << *op.v2();

  // Inline mask, then whatever other attributes remain.
  p << " " << op.mask();
  p.printOptionalAttrDict(op.getAttrs(), {ShuffleOp::getMaskAttrName()});

  // Trailing operand types.
  auto firstType = op.v1()->getType();
  auto secondType = op.v2()->getType();
  p << " : " << firstType << ", " << secondType;
}
/// Verifies the shape-related invariants of a ShuffleOp:
///   * both operands and the result have the same rank,
///   * all non-leading dimension sizes agree across operands and result,
///   * the mask has as many entries as the result's leading dimension,
///   * every mask entry is an integer in [0, dim0(v1) + dim0(v2)).
/// Emits an op error and returns failure on the first violated rule.
static LogicalResult verify(ShuffleOp op) {
  VectorType resultType = op.getVectorType();
  VectorType v1Type = op.getV1VectorType();
  VectorType v2Type = op.getV2VectorType();

  // Rank agreement.
  const int64_t rank = v1Type.getRank();
  if (!(resultType.getRank() == rank && rank == v2Type.getRank()))
    return op.emitOpError("rank mismatch");

  // Trailing dimensions (everything except dimension 0) must be identical.
  for (int64_t dim = 1; dim < rank; ++dim) {
    const int64_t expected = v1Type.getDimSize(dim);
    const bool sameAsResult = resultType.getDimSize(dim) == expected;
    const bool sameAsV2 = v2Type.getDimSize(dim) == expected;
    if (!sameAsResult || !sameAsV2)
      return op.emitOpError("dimension mismatch");
  }

  // One mask entry per slice of the result's leading dimension.
  auto maskValues = op.mask().getValue();
  const int64_t numMaskValues = maskValues.size();
  if (resultType.getDimSize(0) != numMaskValues)
    return op.emitOpError("mask length mismatch");

  // Mask entries index into the concatenation of both leading dimensions.
  const int64_t numSelectable = v1Type.getDimSize(0) + v2Type.getDimSize(0);
  for (size_t slot = 0, e = maskValues.size(); slot != e; ++slot) {
    auto index = maskValues[slot].dyn_cast<IntegerAttr>();
    const bool inRange =
        index && index.getInt() >= 0 && index.getInt() < numSelectable;
    if (!inRange)
      // Positions are reported 1-based.
      return op.emitOpError("mask index #") << (slot + 1) << " out of range";
  }
  return success();
}
/// Parses the custom assembly form of a ShuffleOp:
///   vector.shuffle %v1, %v2 [mask] {optional-attrs} : v1-type, v2-type
/// The result type is not part of the syntax; it is reconstructed from the
/// mask length (leading dimension) and the first operand's type (trailing
/// dimensions and element type).
static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType lhs, rhs;
  Attribute rawMask;
  VectorType lhsType, rhsType;

  // `%v1, %v2`
  if (parser.parseOperand(lhs) || parser.parseComma() ||
      parser.parseOperand(rhs))
    return failure();

  // `[mask]` followed by an optional attribute dictionary.
  if (parser.parseAttribute(rawMask, ShuffleOp::getMaskAttrName(),
                            result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // `: v1-type, v2-type`
  if (parser.parseColonType(lhsType) || parser.parseComma() ||
      parser.parseType(rhsType))
    return failure();

  // Bind the parsed operands to their declared types.
  if (parser.resolveOperand(lhs, lhsType, result.operands) ||
      parser.resolveOperand(rhs, rhsType, result.operands))
    return failure();

  // The mask must be a non-empty array attribute.
  auto mask = rawMask.dyn_cast<ArrayAttr>();
  if (!mask)
    return parser.emitError(parser.getNameLoc(), "missing mask attribute");
  int64_t maskLength = mask.size();
  if (maskLength <= 0)
    return parser.emitError(parser.getNameLoc(), "invalid mask length");

  // Result shape: <maskLength x trailing dimensions of the first operand>.
  int64_t rank = lhsType.getRank();
  SmallVector<int64_t, 4> resultShape;
  resultShape.reserve(rank);
  resultShape.push_back(maskLength);
  for (int64_t dim = 1; dim < rank; ++dim)
    resultShape.push_back(lhsType.getDimSize(dim));

  VectorType resultType =
      VectorType::get(resultShape, lhsType.getElementType());
  parser.addTypeToList(resultType, result.types);
  return success();
}
//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

View File

@ -235,18 +235,18 @@ func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32>
return %0 : vector<3x16xf32>
}
// CHECK-LABEL: extract_vec_2d_from_vec_3d
// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
// CHECK: llvm.return %{{.*}} : !llvm<"[3 x <16 x float>]">
// CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
// CHECK: llvm.return {{.*}} : !llvm<"[3 x <16 x float>]">
func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
%0 = vector.extract %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32>
return %0 : f32
}
// CHECK-LABEL: extract_element_from_vec_3d
// CHECK: llvm.extractvalue %{{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
// CHECK: llvm.extractvalue {{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
// CHECK: llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
// CHECK: llvm.return %{{.*}} : !llvm.float
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
// CHECK: llvm.return {{.*}} : !llvm.float
func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
%0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref<vector<8x8x8xf32>>

View File

@ -31,6 +31,41 @@ func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {
// -----
// Negative test: the operands have different element types (f32 vs. i32), so
// the "second operand v2 and result have same element type" predicate of
// vector.shuffle must fail.
func @shuffle_elt_type_mismatch(%arg0: vector<2xf32>, %arg1: vector<2xi32>) {
// expected-error@+1 {{'vector.shuffle' op failed to verify that second operand v2 and result have same element type}}
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<2xi32>
}
// -----
// Negative test: a 1-D operand is shuffled with a 2-D operand; vector.shuffle
// requires all ranks to match and must report "rank mismatch".
func @shuffle_rank_mismatch(%arg0: vector<2xf32>, %arg1: vector<4x2xf32>) {
// expected-error@+1 {{'vector.shuffle' op rank mismatch}}
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2xf32>, vector<4x2xf32>
}
// -----
// Negative test: the operands agree in rank but differ in their trailing
// dimension size (2 vs. 4); vector.shuffle must report "dimension mismatch".
func @shuffle_trailing_dim_size_mismatch(%arg0: vector<2x2xf32>, %arg1: vector<2x4xf32>) {
// expected-error@+1 {{'vector.shuffle' op dimension mismatch}}
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 1 : i32] : vector<2x2xf32>, vector<2x4xf32>
}
// -----
// Negative test: with two 2-element operands the valid mask values are 0..3,
// so the second mask entry (4) is out of range. The diagnostic reports the
// offending mask position 1-based (#2).
func @shuffle_index_out_of_range(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
// expected-error@+1 {{'vector.shuffle' op mask index #2 out of range}}
%1 = vector.shuffle %arg0, %arg1 [0 : i32, 4 : i32] : vector<2xf32>, vector<2xf32>
}
// -----
// Negative test: an empty mask is rejected with "invalid mask length". The
// "custom op" prefix in the expected message indicates the error is emitted
// while parsing the custom assembly form.
func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
// expected-error@+1 {{custom op 'vector.shuffle' invalid mask length}}
%1 = vector.shuffle %arg0, %arg1 [] : vector<2xf32>, vector<2xf32>
}
// -----
func @extract_vector_type(%arg0: index) {
// expected-error@+1 {{expected vector type}}
%1 = vector.extract %arg0[] : index

View File

@ -24,20 +24,38 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>) {
// CHECK-LABEL: @vector_broadcast
func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> {
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
%0 = vector.broadcast %a : f32 to vector<16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
%1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32>
%2 = vector.broadcast %c : vector<1x16xf32> to vector<8x16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
%3 = vector.broadcast %d : vector<8x1xf32> to vector<8x16xf32>
return %3 : vector<8x16xf32>
}
// Round-trip test for 1-D shuffles. It chains three shuffles: a 4-entry mask
// over two 2-element operands, a 3-entry mask over two 4-element operands,
// and a 2-entry mask over operands of different lengths (3 and 4), whose
// result is returned as vector<2xf32>.
// CHECK-LABEL: @shuffle1D
func @shuffle1D(%a: vector<2xf32>, %b: vector<4xf32>) -> vector<2xf32> {
// CHECK: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 1 : i32, 2 : i32, 3 : i32] : vector<2xf32>, vector<2xf32>
%1 = vector.shuffle %a, %a[0 : i32, 1 : i32, 2: i32, 3 : i32] : vector<2xf32>, vector<2xf32>
// CHECK-NEXT: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 1 : i32, 2 : i32] : vector<4xf32>, vector<4xf32>
%2 = vector.shuffle %1, %b[0 : i32, 1 : i32, 2 : i32] : vector<4xf32>, vector<4xf32>
// CHECK-NEXT: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 6 : i32] : vector<3xf32>, vector<4xf32>
%3 = vector.shuffle %2, %b[0 : i32, 6 : i32] : vector<3xf32>, vector<4xf32>
return %3 : vector<2xf32>
}
// Round-trip test for a 2-D shuffle: the operands differ in their leading
// dimension (1 vs. 2) but share the trailing dimension (4); the 3-entry mask
// yields a vector<3x4xf32> result.
// CHECK-LABEL: @shuffle2D
func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
// CHECK: vector.shuffle %{{.*}}, %{{.*}}[0 : i32, 1 : i32, 2 : i32] : vector<1x4xf32>, vector<2x4xf32>
%1 = vector.shuffle %a, %b[0 : i32, 1 : i32, 2: i32] : vector<1x4xf32>, vector<2x4xf32>
return %1 : vector<3x4xf32>
}
// CHECK-LABEL: @extract
func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
// CHECK: vector.extract {{.*}}[3 : i32] : vector<4x8x16xf32>
// CHECK: vector.extract {{.*}}[3 : i32] : vector<4x8x16xf32>
%1 = vector.extract %arg0[3 : i32] : vector<4x8x16xf32>
// CHECK-NEXT: vector.extract {{.*}}[3 : i32, 3 : i32] : vector<4x8x16xf32>
%2 = vector.extract %arg0[3 : i32, 3 : i32] : vector<4x8x16xf32>
@ -47,35 +65,35 @@ func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f
}
// CHECK-LABEL: @insert
func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) {
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
func @insert(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
%1 = vector.insert %c, %res[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
%2 = vector.insert %b, %res[3 : i32, 3 : i32] : vector<16xf32> into vector<4x8x16xf32>
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
// CHECK: vector.insert %{{.*}}, %{{.*}}[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
%3 = vector.insert %a, %res[3 : i32, 3 : i32, 3 : i32] : f32 into vector<4x8x16xf32>
return
return %3 : vector<4x8x16xf32>
}
// CHECK-LABEL: @outerproduct
func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
%0 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
// CHECK: vector.outerproduct {{.*}}, {{.*}}, {{.*}} : vector<4xf32>, vector<8xf32>
// CHECK: vector.outerproduct {{.*}}, {{.*}}, {{.*}} : vector<4xf32>, vector<8xf32>
%1 = vector.outerproduct %arg0, %arg1, %arg2 : vector<4xf32>, vector<8xf32>
return %1 : vector<4x8xf32>
}
// CHECK-LABEL: @insert_strided_slice
func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
// CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
// CHECK: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
%1 = vector.insert_strided_slice %a, %b {offsets = [2, 2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
return
}
// CHECK-LABEL: @strided_slice
func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
// CHECK: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32>
// CHECK: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32>
%1 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
return %1: vector<2x2x16xf32>
}