[mlir][LLVM] Make the nested type restriction on complex constants less aggressive

Complex nested in other types is perfectly fine, just nested structs
aren't supported. Instead of checking whether there's nesting just check
whether the struct we're dealing with is a complex number.

Differential Revision: https://reviews.llvm.org/D125381
This commit is contained in:
Benjamin Kramer 2022-05-11 15:12:21 +02:00
parent 2a40cc532b
commit 27dad99622
4 changed files with 34 additions and 15 deletions

View File

@ -351,8 +351,7 @@ SetVector<Block *> getTopologicallySortedBlocks(Region &region);
/// report it to `loc` and return nullptr. /// report it to `loc` and return nullptr.
llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr, llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
Location loc, Location loc,
const ModuleTranslation &moduleTranslation, const ModuleTranslation &moduleTranslation);
bool isTopLevel = true);
/// Creates a call to an LLVM IR intrinsic function with the given arguments. /// Creates a call to an LLVM IR intrinsic function with the given arguments.
llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,

View File

@ -214,7 +214,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
(type.isa<VectorType>() || hasVectorElementType)) { (type.isa<VectorType>() || hasVectorElementType)) {
llvm::Constant *splatValue = LLVM::detail::getLLVMConstant( llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc, innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
moduleTranslation, /*isTopLevel=*/false); moduleTranslation);
llvm::Constant *splatVector = llvm::Constant *splatVector =
llvm::ConstantDataVector::getSplat(0, splatValue); llvm::ConstantDataVector::getSplat(0, splatValue);
SmallVector<llvm::Constant *> constants(numAggregates, splatVector); SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
@ -272,22 +272,22 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
/// report it to `loc` and return nullptr. /// report it to `loc` and return nullptr.
llvm::Constant *mlir::LLVM::detail::getLLVMConstant( llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
llvm::Type *llvmType, Attribute attr, Location loc, llvm::Type *llvmType, Attribute attr, Location loc,
const ModuleTranslation &moduleTranslation, bool isTopLevel) { const ModuleTranslation &moduleTranslation) {
if (!attr) if (!attr)
return llvm::UndefValue::get(llvmType); return llvm::UndefValue::get(llvmType);
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) { if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
if (!isTopLevel) { auto arrayAttr = attr.dyn_cast<ArrayAttr>();
emitError(loc, "nested struct types are not supported in constants"); if (!arrayAttr || arrayAttr.size() != 2) {
emitError(loc, "expected struct type to be a complex number");
return nullptr; return nullptr;
} }
auto arrayAttr = attr.cast<ArrayAttr>();
llvm::Type *elementType = structType->getElementType(0); llvm::Type *elementType = structType->getElementType(0);
llvm::Constant *real = getLLVMConstant(elementType, arrayAttr[0], loc, llvm::Constant *real =
moduleTranslation, false); getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
if (!real) if (!real)
return nullptr; return nullptr;
llvm::Constant *imag = getLLVMConstant(elementType, arrayAttr[1], loc, llvm::Constant *imag =
moduleTranslation, false); getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
if (!imag) if (!imag)
return nullptr; return nullptr;
return llvm::ConstantStruct::get(structType, {real, imag}); return llvm::ConstantStruct::get(structType, {real, imag});
@ -336,7 +336,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
elementType, elementType,
elementTypeSequential ? splatAttr elementTypeSequential ? splatAttr
: splatAttr.getSplatValue<Attribute>(), : splatAttr.getSplatValue<Attribute>(),
loc, moduleTranslation, false); loc, moduleTranslation);
if (!child) if (!child)
return nullptr; return nullptr;
if (llvmType->isVectorTy()) if (llvmType->isVectorTy())
@ -367,7 +367,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
llvm::Type *innermostType = getInnermostElementType(llvmType); llvm::Type *innermostType = getInnermostElementType(llvmType);
for (auto n : elementsAttr.getValues<Attribute>()) { for (auto n : elementsAttr.getValues<Attribute>()) {
constants.push_back( constants.push_back(
getLLVMConstant(innermostType, n, loc, moduleTranslation, false)); getLLVMConstant(innermostType, n, loc, moduleTranslation));
if (!constants.back()) if (!constants.back())
return nullptr; return nullptr;
} }

View File

@ -41,14 +41,22 @@ llvm.func @invalid_align(%arg0 : f32 {llvm.align = 4}) -> f32 {
// ----- // -----
llvm.func @no_nested_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> { llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
// expected-error @+1 {{nested struct types are not supported in constants}} // expected-error @+1 {{expected struct type to be a complex number}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
} }
// ----- // -----
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
// expected-error @+1 {{expected struct type to be a complex number}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
}
// -----
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> { llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
// expected-error @+1 {{FloatAttr does not match expected type of the constant}} // expected-error @+1 {{FloatAttr does not match expected type of the constant}}
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)> %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>

View File

@ -1122,6 +1122,18 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
llvm.return %1 : !llvm.struct<(i32, i32)> llvm.return %1 : !llvm.struct<(i32, i32)>
} }
llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> {
%1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
// CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }]
llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>>
}
llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
%1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex<i32>>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>>
// CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]
llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
}
llvm.func @noreach() { llvm.func @noreach() {
// CHECK: unreachable // CHECK: unreachable
llvm.unreachable llvm.unreachable