forked from OSchip/llvm-project
[mlir] Fix scalable type translation in splat element attr
LLVM Dialect Constant Op translations assume that if the attribute is a vector, it's a fixed length one, generating an invalid translation for constant scalable vector initializations. Differential Revision: https://reviews.llvm.org/D117125
This commit is contained in:
parent
36a5491832
commit
7c56458616
|
@ -239,6 +239,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
|||
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
|
||||
llvm::Type *elementType;
|
||||
uint64_t numElements;
|
||||
bool isScalable = false;
|
||||
if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
|
||||
elementType = arrayTy->getElementType();
|
||||
numElements = arrayTy->getNumElements();
|
||||
|
@ -248,6 +249,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
|||
} else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
|
||||
elementType = sVectorTy->getElementType();
|
||||
numElements = sVectorTy->getMinNumElements();
|
||||
isScalable = true;
|
||||
} else {
|
||||
llvm_unreachable("unrecognized constant vector type");
|
||||
}
|
||||
|
@ -265,7 +267,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
|||
return nullptr;
|
||||
if (llvmType->isVectorTy())
|
||||
return llvm::ConstantVector::getSplat(
|
||||
llvm::ElementCount::get(numElements, /*Scalable=*/false), child);
|
||||
llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child);
|
||||
if (llvmType->isArrayTy()) {
|
||||
auto *arrayType = llvm::ArrayType::get(elementType, numElements);
|
||||
SmallVector<llvm::Constant *, 8> constants(numElements, child);
|
||||
|
|
|
@ -90,6 +90,8 @@ llvm.func @return_ppi8_42_9() -> !llvm.ptr<ptr<i8, 42>, 9>
|
|||
llvm.func @return_v4_i32() -> vector<4xi32>
|
||||
// CHECK: declare <4 x float> @return_v4_float()
|
||||
llvm.func @return_v4_float() -> vector<4xf32>
|
||||
// CHECK: declare <vscale x 4 x float> @return_vs_4_float()
|
||||
llvm.func @return_vs_4_float() -> vector<[4]xf32>
|
||||
// CHECK: declare <vscale x 4 x i32> @return_vs_4_i32()
|
||||
llvm.func @return_vs_4_i32() -> !llvm.vec<?x4 x i32>
|
||||
// CHECK: declare <vscale x 8 x half> @return_vs_8_half()
|
||||
|
|
|
@ -907,6 +907,13 @@ llvm.func @vector_splat_1d() -> vector<4xf32> {
|
|||
llvm.return %0 : vector<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @vector_splat_1d_scalable
|
||||
llvm.func @vector_splat_1d_scalable() -> vector<[4]xf32> {
|
||||
// CHECK: ret <vscale x 4 x float> zeroinitializer
|
||||
%0 = llvm.mlir.constant(dense<0.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
|
||||
llvm.return %0 : vector<[4]xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @vector_splat_2d
|
||||
llvm.func @vector_splat_2d() -> !llvm.array<4 x vector<16 x f32>> {
|
||||
// CHECK: ret [4 x <16 x float>] zeroinitializer
|
||||
|
@ -928,6 +935,13 @@ llvm.func @vector_splat_nonzero() -> vector<4xf32> {
|
|||
llvm.return %0 : vector<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @vector_splat_nonzero_scalable
|
||||
llvm.func @vector_splat_nonzero_scalable() -> vector<[4]xf32> {
|
||||
// CHECK: ret <vscale x 4 x float> shufflevector (<vscale x 4 x float> insertelement (<vscale x 4 x float> poison, float 1.000000e+00, i32 0), <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer)
|
||||
%0 = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
|
||||
llvm.return %0 : vector<[4]xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ops
|
||||
llvm.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32) -> !llvm.struct<(f32, i32)> {
|
||||
// CHECK-NEXT: fsub float %0, %1
|
||||
|
|
Loading…
Reference in New Issue