[mlir][LLVM] Allow scalable vectors in ShuffleVectorOp

The current implementation of ShuffleVectorOp assumes all vectors are
scalable. LLVM IR allows shufflevector operations on scalable vectors,
and the current translation between LLVM Dialect and LLVM IR does the
rigth thing when the shuffle mask is all zeroes. This is required to
do a splat operation on a scalable vector, but it doesn't make sense
for scalable vectors outside of that operation, i.e.: with non-all zero
masks.

Differential Revision: https://reviews.llvm.org/D118371
This commit is contained in:
Javier Setoain 2022-01-27 13:36:16 +00:00
parent e41a138520
commit cd0d21b47b
4 changed files with 52 additions and 4 deletions

View File

@ -1881,8 +1881,9 @@ void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result,
Value v1, Value v2, ArrayAttr mask,
ArrayRef<NamedAttribute> attrs) {
auto containerType = v1.getType();
auto vType = LLVM::getFixedVectorType(
LLVM::getVectorElementType(containerType), mask.size());
auto vType = LLVM::getVectorType(
LLVM::getVectorElementType(containerType), mask.size(),
containerType.cast<VectorType>().isScalable());
build(b, result, vType, v1, v2, mask);
result.addAttributes(attrs);
}
@ -1914,8 +1915,9 @@ ParseResult ShuffleVectorOp::parse(OpAsmParser &parser,
if (!LLVM::isCompatibleVectorType(typeV1))
return parser.emitError(
loc, "expected LLVM IR dialect vector type for operand #1");
auto vType = LLVM::getFixedVectorType(LLVM::getVectorElementType(typeV1),
maskAttr.size());
auto vType =
LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(),
typeV1.cast<VectorType>().isScalable());
result.addTypes(vType);
return success();
}
@ -1925,6 +1927,11 @@ LogicalResult ShuffleVectorOp::verify() {
Type type2 = getV2().getType();
if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2))
return emitOpError("expected matching LLVM IR Dialect element types");
if (LLVM::isScalableVectorType(type1))
if (llvm::any_of(getMask(), [](Attribute attr) {
return attr.cast<IntegerAttr>().getInt() != 0;
}))
return emitOpError("expected a splat operation for scalable vectors");
return success();
}

View File

@ -1250,3 +1250,11 @@ func @gep_out_of_bounds(%ptr: !llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, %idx
llvm.getelementptr %ptr[%idx, 1, 3] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
return
}
// -----
func @non_splat_shuffle_on_scalable_vector(%arg0: vector<[4]xf32>) {
// expected-error@+1 {{expected a splat operation for scalable vectors}}
%0 = llvm.shufflevector %arg0, %arg0 [0 : i32, 0 : i32, 0 : i32, 1 : i32] : vector<[4]xf32>, vector<[4]xf32>
return
}

View File

@ -281,6 +281,19 @@ func @vect(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
return
}
// CHECK-LABEL: @scalable_vect
func @scalable_vect(%arg0: vector<[4]xf32>, %arg1: i32, %arg2: f32) {
// CHECK: = llvm.extractelement {{.*}} : vector<[4]xf32>
%0 = llvm.extractelement %arg0[%arg1 : i32] : vector<[4]xf32>
// CHECK: = llvm.insertelement {{.*}} : vector<[4]xf32>
%1 = llvm.insertelement %arg2, %arg0[%arg1 : i32] : vector<[4]xf32>
// CHECK: = llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[4]xf32>, vector<[4]xf32>
%2 = llvm.shufflevector %arg0, %arg0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[4]xf32>, vector<[4]xf32>
// CHECK: = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
%3 = llvm.mlir.constant(dense<1.0> : vector<[4]xf32>) : vector<[4]xf32>
return
}
// CHECK-LABEL: @alloca
func @alloca(%size : i64) {
// CHECK: llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr<i32>

View File

@ -1168,6 +1168,26 @@ llvm.func @vect_i64idx(%arg0: vector<4xf32>, %arg1: i64, %arg2: f32) {
llvm.return
}
// CHECK-LABEL: @scalable_vect
llvm.func @scalable_vect(%arg0: vector<[4]xf32>, %arg1: i32, %arg2: f32) {
// CHECK-NEXT: extractelement <vscale x 4 x float> {{.*}}, i32
// CHECK-NEXT: insertelement <vscale x 4 x float> {{.*}}, float %2, i32
// CHECK-NEXT: shufflevector <vscale x 4 x float> %0, <vscale x 4 x float> %0, <vscale x 4 x i32> zeroinitializer
%0 = llvm.extractelement %arg0[%arg1 : i32] : vector<[4]xf32>
%1 = llvm.insertelement %arg2, %arg0[%arg1 : i32] : vector<[4]xf32>
%2 = llvm.shufflevector %arg0, %arg0 [0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<[4]xf32>, vector<[4]xf32>
llvm.return
}
// CHECK-LABEL: @scalable_vect_i64idx
llvm.func @scalable_vect_i64idx(%arg0: vector<[4]xf32>, %arg1: i64, %arg2: f32) {
// CHECK-NEXT: extractelement <vscale x 4 x float> {{.*}}, i64
// CHECK-NEXT: insertelement <vscale x 4 x float> {{.*}}, float %2, i64
%0 = llvm.extractelement %arg0[%arg1 : i64] : vector<[4]xf32>
%1 = llvm.insertelement %arg2, %arg0[%arg1 : i64] : vector<[4]xf32>
llvm.return
}
// CHECK-LABEL: @alloca
llvm.func @alloca(%size : i64) {
// Alignment automatically set by the LLVM IR builder when alignment attribute