forked from OSchip/llvm-project
[mlir][LLVM] Make the nested type restriction on complex constants less aggressive
Complex nested in other types is perfectly fine, just nested structs aren't supported. Instead of checking whether there's nesting just check whether the struct we're dealing with is a complex number. Differential Revision: https://reviews.llvm.org/D125381
This commit is contained in:
parent
2a40cc532b
commit
27dad99622
|
@ -351,8 +351,7 @@ SetVector<Block *> getTopologicallySortedBlocks(Region ®ion);
|
|||
/// report it to `loc` and return nullptr.
|
||||
llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
|
||||
Location loc,
|
||||
const ModuleTranslation &moduleTranslation,
|
||||
bool isTopLevel = true);
|
||||
const ModuleTranslation &moduleTranslation);
|
||||
|
||||
/// Creates a call to an LLVM IR intrinsic function with the given arguments.
|
||||
llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,
|
||||
|
|
|
@ -214,7 +214,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
|
|||
(type.isa<VectorType>() || hasVectorElementType)) {
|
||||
llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
|
||||
innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
|
||||
moduleTranslation, /*isTopLevel=*/false);
|
||||
moduleTranslation);
|
||||
llvm::Constant *splatVector =
|
||||
llvm::ConstantDataVector::getSplat(0, splatValue);
|
||||
SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
|
||||
|
@ -272,22 +272,22 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
|
|||
/// report it to `loc` and return nullptr.
|
||||
llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
||||
llvm::Type *llvmType, Attribute attr, Location loc,
|
||||
const ModuleTranslation &moduleTranslation, bool isTopLevel) {
|
||||
const ModuleTranslation &moduleTranslation) {
|
||||
if (!attr)
|
||||
return llvm::UndefValue::get(llvmType);
|
||||
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
|
||||
if (!isTopLevel) {
|
||||
emitError(loc, "nested struct types are not supported in constants");
|
||||
auto arrayAttr = attr.dyn_cast<ArrayAttr>();
|
||||
if (!arrayAttr || arrayAttr.size() != 2) {
|
||||
emitError(loc, "expected struct type to be a complex number");
|
||||
return nullptr;
|
||||
}
|
||||
auto arrayAttr = attr.cast<ArrayAttr>();
|
||||
llvm::Type *elementType = structType->getElementType(0);
|
||||
llvm::Constant *real = getLLVMConstant(elementType, arrayAttr[0], loc,
|
||||
moduleTranslation, false);
|
||||
llvm::Constant *real =
|
||||
getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
|
||||
if (!real)
|
||||
return nullptr;
|
||||
llvm::Constant *imag = getLLVMConstant(elementType, arrayAttr[1], loc,
|
||||
moduleTranslation, false);
|
||||
llvm::Constant *imag =
|
||||
getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
|
||||
if (!imag)
|
||||
return nullptr;
|
||||
return llvm::ConstantStruct::get(structType, {real, imag});
|
||||
|
@ -336,7 +336,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
|||
elementType,
|
||||
elementTypeSequential ? splatAttr
|
||||
: splatAttr.getSplatValue<Attribute>(),
|
||||
loc, moduleTranslation, false);
|
||||
loc, moduleTranslation);
|
||||
if (!child)
|
||||
return nullptr;
|
||||
if (llvmType->isVectorTy())
|
||||
|
@ -367,7 +367,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
|||
llvm::Type *innermostType = getInnermostElementType(llvmType);
|
||||
for (auto n : elementsAttr.getValues<Attribute>()) {
|
||||
constants.push_back(
|
||||
getLLVMConstant(innermostType, n, loc, moduleTranslation, false));
|
||||
getLLVMConstant(innermostType, n, loc, moduleTranslation));
|
||||
if (!constants.back())
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -41,14 +41,22 @@ llvm.func @invalid_align(%arg0 : f32 {llvm.align = 4}) -> f32 {
|
|||
|
||||
// -----
|
||||
|
||||
llvm.func @no_nested_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
|
||||
// expected-error @+1 {{nested struct types are not supported in constants}}
|
||||
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
|
||||
// expected-error @+1 {{expected struct type to be a complex number}}
|
||||
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
|
||||
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
|
||||
// expected-error @+1 {{expected struct type to be a complex number}}
|
||||
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
|
||||
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
|
||||
// expected-error @+1 {{FloatAttr does not match expected type of the constant}}
|
||||
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
|
||||
|
|
|
@ -1122,6 +1122,18 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
|
|||
llvm.return %1 : !llvm.struct<(i32, i32)>
|
||||
}
|
||||
|
||||
llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> {
|
||||
%1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
|
||||
// CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }]
|
||||
llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>>
|
||||
}
|
||||
|
||||
llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
|
||||
%1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex<i32>>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>>
|
||||
// CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]
|
||||
llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
|
||||
}
|
||||
|
||||
llvm.func @noreach() {
|
||||
// CHECK: unreachable
|
||||
llvm.unreachable
|
||||
|
|
Loading…
Reference in New Issue