[mlir][LLVM] Make the nested type restriction on complex constants less aggressive

Complex nested in other types is perfectly fine, just nested structs
aren't supported. Instead of checking whether there's nesting just check
whether the struct we're dealing with is a complex number.

Differential Revision: https://reviews.llvm.org/D125381
This commit is contained in:
Benjamin Kramer 2022-05-11 15:12:21 +02:00
parent 2a40cc532b
commit 27dad99622
4 changed files with 34 additions and 15 deletions

View File

@ -351,8 +351,7 @@ SetVector<Block *> getTopologicallySortedBlocks(Region &region);
/// report it to `loc` and return nullptr.
llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
Location loc,
const ModuleTranslation &moduleTranslation,
bool isTopLevel = true);
const ModuleTranslation &moduleTranslation);
/// Creates a call to an LLVM IR intrinsic function with the given arguments.
llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,

View File

@ -214,7 +214,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
(type.isa<VectorType>() || hasVectorElementType)) {
llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
moduleTranslation, /*isTopLevel=*/false);
moduleTranslation);
llvm::Constant *splatVector =
llvm::ConstantDataVector::getSplat(0, splatValue);
SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
@ -272,22 +272,22 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
/// report it to `loc` and return nullptr.
llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
llvm::Type *llvmType, Attribute attr, Location loc,
const ModuleTranslation &moduleTranslation, bool isTopLevel) {
const ModuleTranslation &moduleTranslation) {
if (!attr)
return llvm::UndefValue::get(llvmType);
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
if (!isTopLevel) {
emitError(loc, "nested struct types are not supported in constants");
auto arrayAttr = attr.dyn_cast<ArrayAttr>();
if (!arrayAttr || arrayAttr.size() != 2) {
emitError(loc, "expected struct type to be a complex number");
return nullptr;
}
auto arrayAttr = attr.cast<ArrayAttr>();
llvm::Type *elementType = structType->getElementType(0);
llvm::Constant *real = getLLVMConstant(elementType, arrayAttr[0], loc,
moduleTranslation, false);
llvm::Constant *real =
getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
if (!real)
return nullptr;
llvm::Constant *imag = getLLVMConstant(elementType, arrayAttr[1], loc,
moduleTranslation, false);
llvm::Constant *imag =
getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
if (!imag)
return nullptr;
return llvm::ConstantStruct::get(structType, {real, imag});
@ -336,7 +336,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
elementType,
elementTypeSequential ? splatAttr
: splatAttr.getSplatValue<Attribute>(),
loc, moduleTranslation, false);
loc, moduleTranslation);
if (!child)
return nullptr;
if (llvmType->isVectorTy())
@ -367,7 +367,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
llvm::Type *innermostType = getInnermostElementType(llvmType);
for (auto n : elementsAttr.getValues<Attribute>()) {
constants.push_back(
getLLVMConstant(innermostType, n, loc, moduleTranslation, false));
getLLVMConstant(innermostType, n, loc, moduleTranslation));
if (!constants.back())
return nullptr;
}

View File

@ -41,14 +41,22 @@ llvm.func @invalid_align(%arg0 : f32 {llvm.align = 4}) -> f32 {
// -----
llvm.func @no_nested_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
// expected-error @+1 {{nested struct types are not supported in constants}}
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
// expected-error @+1 {{expected struct type to be a complex number}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
}
// -----
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
// expected-error @+1 {{expected struct type to be a complex number}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
}
// -----
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
// expected-error @+1 {{FloatAttr does not match expected type of the constant}}
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>

View File

@ -1122,6 +1122,18 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
llvm.return %1 : !llvm.struct<(i32, i32)>
}
llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> {
%1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
// CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }]
llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>>
}
llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
%1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex<i32>>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>>
// CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]
llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
}
llvm.func @noreach() {
// CHECK: unreachable
llvm.unreachable