forked from OSchip/llvm-project
[mlir][LLVM] Make the nested type restriction on complex constants less aggressive
Complex nested in other types is perfectly fine, just nested structs aren't supported. Instead of checking whether there's nesting just check whether the struct we're dealing with is a complex number. Differential Revision: https://reviews.llvm.org/D125381
This commit is contained in:
parent
2a40cc532b
commit
27dad99622
|
@ -351,8 +351,7 @@ SetVector<Block *> getTopologicallySortedBlocks(Region ®ion);
|
||||||
/// report it to `loc` and return nullptr.
|
/// report it to `loc` and return nullptr.
|
||||||
llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
|
llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
|
||||||
Location loc,
|
Location loc,
|
||||||
const ModuleTranslation &moduleTranslation,
|
const ModuleTranslation &moduleTranslation);
|
||||||
bool isTopLevel = true);
|
|
||||||
|
|
||||||
/// Creates a call to an LLVM IR intrinsic function with the given arguments.
|
/// Creates a call to an LLVM IR intrinsic function with the given arguments.
|
||||||
llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,
|
llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,
|
||||||
|
|
|
@ -214,7 +214,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
|
||||||
(type.isa<VectorType>() || hasVectorElementType)) {
|
(type.isa<VectorType>() || hasVectorElementType)) {
|
||||||
llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
|
llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
|
||||||
innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
|
innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
|
||||||
moduleTranslation, /*isTopLevel=*/false);
|
moduleTranslation);
|
||||||
llvm::Constant *splatVector =
|
llvm::Constant *splatVector =
|
||||||
llvm::ConstantDataVector::getSplat(0, splatValue);
|
llvm::ConstantDataVector::getSplat(0, splatValue);
|
||||||
SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
|
SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
|
||||||
|
@ -272,22 +272,22 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
|
||||||
/// report it to `loc` and return nullptr.
|
/// report it to `loc` and return nullptr.
|
||||||
llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
||||||
llvm::Type *llvmType, Attribute attr, Location loc,
|
llvm::Type *llvmType, Attribute attr, Location loc,
|
||||||
const ModuleTranslation &moduleTranslation, bool isTopLevel) {
|
const ModuleTranslation &moduleTranslation) {
|
||||||
if (!attr)
|
if (!attr)
|
||||||
return llvm::UndefValue::get(llvmType);
|
return llvm::UndefValue::get(llvmType);
|
||||||
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
|
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
|
||||||
if (!isTopLevel) {
|
auto arrayAttr = attr.dyn_cast<ArrayAttr>();
|
||||||
emitError(loc, "nested struct types are not supported in constants");
|
if (!arrayAttr || arrayAttr.size() != 2) {
|
||||||
|
emitError(loc, "expected struct type to be a complex number");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto arrayAttr = attr.cast<ArrayAttr>();
|
|
||||||
llvm::Type *elementType = structType->getElementType(0);
|
llvm::Type *elementType = structType->getElementType(0);
|
||||||
llvm::Constant *real = getLLVMConstant(elementType, arrayAttr[0], loc,
|
llvm::Constant *real =
|
||||||
moduleTranslation, false);
|
getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
|
||||||
if (!real)
|
if (!real)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
llvm::Constant *imag = getLLVMConstant(elementType, arrayAttr[1], loc,
|
llvm::Constant *imag =
|
||||||
moduleTranslation, false);
|
getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
|
||||||
if (!imag)
|
if (!imag)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return llvm::ConstantStruct::get(structType, {real, imag});
|
return llvm::ConstantStruct::get(structType, {real, imag});
|
||||||
|
@ -336,7 +336,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
||||||
elementType,
|
elementType,
|
||||||
elementTypeSequential ? splatAttr
|
elementTypeSequential ? splatAttr
|
||||||
: splatAttr.getSplatValue<Attribute>(),
|
: splatAttr.getSplatValue<Attribute>(),
|
||||||
loc, moduleTranslation, false);
|
loc, moduleTranslation);
|
||||||
if (!child)
|
if (!child)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
if (llvmType->isVectorTy())
|
if (llvmType->isVectorTy())
|
||||||
|
@ -367,7 +367,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
||||||
llvm::Type *innermostType = getInnermostElementType(llvmType);
|
llvm::Type *innermostType = getInnermostElementType(llvmType);
|
||||||
for (auto n : elementsAttr.getValues<Attribute>()) {
|
for (auto n : elementsAttr.getValues<Attribute>()) {
|
||||||
constants.push_back(
|
constants.push_back(
|
||||||
getLLVMConstant(innermostType, n, loc, moduleTranslation, false));
|
getLLVMConstant(innermostType, n, loc, moduleTranslation));
|
||||||
if (!constants.back())
|
if (!constants.back())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,14 +41,22 @@ llvm.func @invalid_align(%arg0 : f32 {llvm.align = 4}) -> f32 {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
llvm.func @no_nested_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
|
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
|
||||||
// expected-error @+1 {{nested struct types are not supported in constants}}
|
// expected-error @+1 {{expected struct type to be a complex number}}
|
||||||
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
|
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
|
||||||
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
|
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
|
||||||
|
// expected-error @+1 {{expected struct type to be a complex number}}
|
||||||
|
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
|
||||||
|
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
|
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
|
||||||
// expected-error @+1 {{FloatAttr does not match expected type of the constant}}
|
// expected-error @+1 {{FloatAttr does not match expected type of the constant}}
|
||||||
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
|
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
|
||||||
|
|
|
@ -1122,6 +1122,18 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
|
||||||
llvm.return %1 : !llvm.struct<(i32, i32)>
|
llvm.return %1 : !llvm.struct<(i32, i32)>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> {
|
||||||
|
%1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
|
||||||
|
// CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }]
|
||||||
|
llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>>
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
|
||||||
|
%1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex<i32>>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>>
|
||||||
|
// CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]
|
||||||
|
llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>>
|
||||||
|
}
|
||||||
|
|
||||||
llvm.func @noreach() {
|
llvm.func @noreach() {
|
||||||
// CHECK: unreachable
|
// CHECK: unreachable
|
||||||
llvm.unreachable
|
llvm.unreachable
|
||||||
|
|
Loading…
Reference in New Issue