[mlir] Fix scalable type translation in splat element attr

LLVM Dialect Constant Op translations assume that if the attribute is a
vector, it's a fixed length one, generating an invalid translation for
constant scalable vector initializations.

Differential Revision: https://reviews.llvm.org/D117125
This commit is contained in:
Javier Setoain 2022-01-12 17:01:07 +00:00
parent 36a5491832
commit 7c56458616
3 changed files with 19 additions and 1 deletions

View File

@ -239,6 +239,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
llvm::Type *elementType;
uint64_t numElements;
bool isScalable = false;
if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
elementType = arrayTy->getElementType();
numElements = arrayTy->getNumElements();
@ -248,6 +249,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
} else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
elementType = sVectorTy->getElementType();
numElements = sVectorTy->getMinNumElements();
isScalable = true;
} else {
llvm_unreachable("unrecognized constant vector type");
}
@ -265,7 +267,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
return nullptr;
if (llvmType->isVectorTy())
return llvm::ConstantVector::getSplat(
llvm::ElementCount::get(numElements, /*Scalable=*/false), child);
llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child);
if (llvmType->isArrayTy()) {
auto *arrayType = llvm::ArrayType::get(elementType, numElements);
SmallVector<llvm::Constant *, 8> constants(numElements, child);

View File

@ -90,6 +90,8 @@ llvm.func @return_ppi8_42_9() -> !llvm.ptr<ptr<i8, 42>, 9>
llvm.func @return_v4_i32() -> vector<4xi32>
// CHECK: declare <4 x float> @return_v4_float()
llvm.func @return_v4_float() -> vector<4xf32>
// CHECK: declare <vscale x 4 x float> @return_vs_4_float()
llvm.func @return_vs_4_float() -> vector<[4]xf32>
// CHECK: declare <vscale x 4 x i32> @return_vs_4_i32()
llvm.func @return_vs_4_i32() -> !llvm.vec<?x4 x i32>
// CHECK: declare <vscale x 8 x half> @return_vs_8_half()

View File

@ -907,6 +907,13 @@ llvm.func @vector_splat_1d() -> vector<4xf32> {
llvm.return %0 : vector<4xf32>
}
// CHECK-LABEL: @vector_splat_1d_scalable
llvm.func @vector_splat_1d_scalable() -> vector<[4]xf32> {
// CHECK: ret <vscale x 4 x float> zeroinitializer
%0 = llvm.mlir.constant(dense<0.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
llvm.return %0 : vector<[4]xf32>
}
// CHECK-LABEL: @vector_splat_2d
llvm.func @vector_splat_2d() -> !llvm.array<4 x vector<16 x f32>> {
// CHECK: ret [4 x <16 x float>] zeroinitializer
@ -928,6 +935,13 @@ llvm.func @vector_splat_nonzero() -> vector<4xf32> {
llvm.return %0 : vector<4xf32>
}
// CHECK-LABEL: @vector_splat_nonzero_scalable
llvm.func @vector_splat_nonzero_scalable() -> vector<[4]xf32> {
// CHECK: ret <vscale x 4 x float> shufflevector (<vscale x 4 x float> insertelement (<vscale x 4 x float> poison, float 1.000000e+00, i32 0), <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer)
%0 = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
llvm.return %0 : vector<[4]xf32>
}
// CHECK-LABEL: @ops
llvm.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32) -> !llvm.struct<(f32, i32)> {
// CHECK-NEXT: fsub float %0, %1