forked from OSchip/llvm-project
Add support for complex constants to MLIR core.
BEGIN_PUBLIC Add support for complex constants to MLIR core. END_PUBLIC Differential Revision: https://reviews.llvm.org/D101908
This commit is contained in:
parent
7ddeffee55
commit
5ef21506b9
|
@ -1092,7 +1092,10 @@ def ConstantOp : Std_Op<"constant",
|
|||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Attribute":$value),
|
||||
[{ build($_builder, $_state, value.getType(), value); }]>];
|
||||
[{ build($_builder, $_state, value.getType(), value); }]>,
|
||||
OpBuilder<(ins "Attribute":$value, "Type":$type),
|
||||
[{ build($_builder, $_state, type, value); }]>,
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Attribute getValue() { return (*this)->getAttr("value"); }
|
||||
|
|
|
@ -326,11 +326,13 @@ SetVector<Block *> getTopologicallySortedBlocks(Region ®ion);
|
|||
|
||||
/// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
|
||||
/// This currently supports integer, floating point, splat and dense element
|
||||
/// attributes and combinations thereof. In case of error, report it to `loc`
|
||||
/// and return nullptr.
|
||||
/// attributes and combinations thereof. Also, an array attribute with two
|
||||
/// elements is supported to represent a complex constant. In case of error,
|
||||
/// report it to `loc` and return nullptr.
|
||||
llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
|
||||
Location loc,
|
||||
const ModuleTranslation &moduleTranslation);
|
||||
const ModuleTranslation &moduleTranslation,
|
||||
bool isTopLevel = true);
|
||||
|
||||
/// Creates a call to an LLVM IR intrinsic function with the given arguments.
|
||||
llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,
|
||||
|
|
|
@ -1962,7 +1962,30 @@ static LogicalResult verify(LLVM::ConstantOp op) {
|
|||
}
|
||||
return success();
|
||||
}
|
||||
if (!op.value().isa<IntegerAttr, FloatAttr, ElementsAttr>())
|
||||
if (auto structType = op.getType().dyn_cast<LLVMStructType>()) {
|
||||
if (structType.getBody().size() != 2 ||
|
||||
structType.getBody()[0] != structType.getBody()[1]) {
|
||||
return op.emitError() << "expected struct type with two elements of the "
|
||||
"same type, the type of a complex constant";
|
||||
}
|
||||
|
||||
auto arrayAttr = op.value().dyn_cast<ArrayAttr>();
|
||||
if (!arrayAttr || arrayAttr.size() != 2 ||
|
||||
arrayAttr[0].getType() != arrayAttr[1].getType()) {
|
||||
return op.emitOpError() << "expected array attribute with two elements, "
|
||||
"representing a complex constant";
|
||||
}
|
||||
|
||||
Type elementType = structType.getBody()[0];
|
||||
if (!elementType
|
||||
.isa<IntegerType, Float16Type, Float32Type, Float64Type>()) {
|
||||
return op.emitError()
|
||||
<< "expected struct element types to be floating point type or "
|
||||
"integer type";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
if (!op.value().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
|
||||
return op.emitOpError()
|
||||
<< "only supports integer, float, string or elements attributes";
|
||||
return success();
|
||||
|
|
|
@ -1200,7 +1200,8 @@ public:
|
|||
|
||||
// Create a constant scalar value from the splat constant.
|
||||
Value scalarConstant = rewriter.create<ConstantOp>(
|
||||
def->getLoc(), constantAttr.getSplatValue());
|
||||
def->getLoc(), constantAttr.getSplatValue(),
|
||||
constantAttr.getType().getElementType());
|
||||
|
||||
auto fusedOp = rewriter.create<GenericOp>(
|
||||
rewriter.getUnknownLoc(), genericOp->getResultTypes(),
|
||||
|
|
|
@ -1067,8 +1067,8 @@ static void print(OpAsmPrinter &p, ConstantOp &op) {
|
|||
p << ' ';
|
||||
p << op.getValue();
|
||||
|
||||
// If the value is a symbol reference, print a trailing type.
|
||||
if (op.getValue().isa<SymbolRefAttr>())
|
||||
// If the value is a symbol reference or Array, print a trailing type.
|
||||
if (op.getValue().isa<SymbolRefAttr, ArrayAttr>())
|
||||
p << " : " << op.getType();
|
||||
}
|
||||
|
||||
|
@ -1079,9 +1079,10 @@ static ParseResult parseConstantOp(OpAsmParser &parser,
|
|||
parser.parseAttribute(valueAttr, "value", result.attributes))
|
||||
return failure();
|
||||
|
||||
// If the attribute is a symbol reference, then we expect a trailing type.
|
||||
// If the attribute is a symbol reference or array, then we expect a trailing
|
||||
// type.
|
||||
Type type;
|
||||
if (!valueAttr.isa<SymbolRefAttr>())
|
||||
if (!valueAttr.isa<SymbolRefAttr, ArrayAttr>())
|
||||
type = valueAttr.getType();
|
||||
else if (parser.parseColonType(type))
|
||||
return failure();
|
||||
|
@ -1119,6 +1120,24 @@ static LogicalResult verify(ConstantOp &op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
if (auto complexTy = type.dyn_cast<ComplexType>()) {
|
||||
auto arrayAttr = value.dyn_cast<ArrayAttr>();
|
||||
if (!complexTy || arrayAttr.size() != 2)
|
||||
return op.emitOpError(
|
||||
"requires 'value' to be a complex constant, represented as array of "
|
||||
"two values");
|
||||
auto complexEltTy = complexTy.getElementType();
|
||||
if (complexEltTy != arrayAttr[0].getType() ||
|
||||
complexEltTy != arrayAttr[1].getType()) {
|
||||
return op.emitOpError()
|
||||
<< "requires attribute's element types (" << arrayAttr[0].getType()
|
||||
<< ", " << arrayAttr[1].getType()
|
||||
<< ") to match the element type of the op's return type ("
|
||||
<< complexEltTy << ")";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
if (type.isa<FloatType>()) {
|
||||
if (!value.isa<FloatAttr>())
|
||||
return op.emitOpError("requires 'value' to be a floating point constant");
|
||||
|
@ -1193,13 +1212,21 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
|
|||
if (value.isa<SymbolRefAttr>())
|
||||
return type.isa<FunctionType>();
|
||||
// The attribute must have the same type as 'type'.
|
||||
if (value.getType() != type)
|
||||
if (!value.getType().isa<NoneType>() && value.getType() != type)
|
||||
return false;
|
||||
// If the type is an integer type, it must be signless.
|
||||
if (IntegerType integerTy = type.dyn_cast<IntegerType>())
|
||||
if (!integerTy.isSignless())
|
||||
return false;
|
||||
// Finally, check that the attribute kind is handled.
|
||||
if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
|
||||
auto complexTy = type.dyn_cast<ComplexType>();
|
||||
if (!complexTy)
|
||||
return false;
|
||||
auto complexEltTy = complexTy.getElementType();
|
||||
return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
|
||||
arrAttr[1].getType() == complexEltTy;
|
||||
}
|
||||
return value.isa<IntegerAttr, FloatAttr, ElementsAttr, UnitAttr>();
|
||||
}
|
||||
|
||||
|
|
|
@ -578,6 +578,25 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
|
|||
FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
|
||||
return FloatAttr::get(eltTy, *floatIt);
|
||||
}
|
||||
if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
|
||||
auto complexEltTy = complexTy.getElementType();
|
||||
ComplexIntElementIterator complexIntIt(owner, index);
|
||||
if (complexEltTy.isa<IntegerType>()) {
|
||||
auto value = *complexIntIt;
|
||||
auto real = IntegerAttr::get(complexEltTy, value.real());
|
||||
auto imag = IntegerAttr::get(complexEltTy, value.imag());
|
||||
return ArrayAttr::get(complexTy.getContext(),
|
||||
ArrayRef<Attribute>{real, imag});
|
||||
}
|
||||
|
||||
ComplexFloatElementIterator complexFloatIt(
|
||||
complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
|
||||
auto value = *complexFloatIt;
|
||||
auto real = FloatAttr::get(complexEltTy, value.real());
|
||||
auto imag = FloatAttr::get(complexEltTy, value.imag());
|
||||
return ArrayAttr::get(complexTy.getContext(),
|
||||
ArrayRef<Attribute>{real, imag});
|
||||
}
|
||||
if (owner.isa<DenseStringElementsAttr>()) {
|
||||
ArrayRef<StringRef> vals = owner.getRawStringData();
|
||||
return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
|
||||
|
|
|
@ -103,16 +103,30 @@ static llvm::Type *getInnermostElementType(llvm::Type *type) {
|
|||
|
||||
/// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
|
||||
/// This currently supports integer, floating point, splat and dense element
|
||||
/// attributes and combinations thereof. In case of error, report it to `loc`
|
||||
/// and return nullptr.
|
||||
/// attributes and combinations thereof. Also, an array attribute with two
|
||||
/// elements is supported to represent a complex constant. In case of error,
|
||||
/// report it to `loc` and return nullptr.
|
||||
llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
||||
llvm::Type *llvmType, Attribute attr, Location loc,
|
||||
const ModuleTranslation &moduleTranslation) {
|
||||
const ModuleTranslation &moduleTranslation, bool isTopLevel) {
|
||||
if (!attr)
|
||||
return llvm::UndefValue::get(llvmType);
|
||||
if (llvmType->isStructTy()) {
|
||||
emitError(loc, "struct types are not supported in constants");
|
||||
return nullptr;
|
||||
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
|
||||
if (!isTopLevel) {
|
||||
emitError(loc, "nested struct types are not supported in constants");
|
||||
return nullptr;
|
||||
}
|
||||
auto arrayAttr = attr.cast<ArrayAttr>();
|
||||
llvm::Type *elementType = structType->getElementType(0);
|
||||
llvm::Constant *real = getLLVMConstant(elementType, arrayAttr[0], loc,
|
||||
moduleTranslation, false);
|
||||
if (!real)
|
||||
return nullptr;
|
||||
llvm::Constant *imag = getLLVMConstant(elementType, arrayAttr[1], loc,
|
||||
moduleTranslation, false);
|
||||
if (!imag)
|
||||
return nullptr;
|
||||
return llvm::ConstantStruct::get(structType, {real, imag});
|
||||
}
|
||||
// For integer types, we allow a mismatch in sizes as the index type in
|
||||
// MLIR might have a different size than the index type in the LLVM module.
|
||||
|
@ -120,8 +134,15 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
|||
return llvm::ConstantInt::get(
|
||||
llvmType,
|
||||
intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
|
||||
if (auto floatAttr = attr.dyn_cast<FloatAttr>())
|
||||
if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
|
||||
if (llvmType !=
|
||||
llvm::Type::getFloatingPointTy(llvmType->getContext(),
|
||||
floatAttr.getValue().getSemantics())) {
|
||||
emitError(loc, "FloatAttr does not match expected type of the constant");
|
||||
return nullptr;
|
||||
}
|
||||
return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
|
||||
}
|
||||
if (auto funcAttr = attr.dyn_cast<FlatSymbolRefAttr>())
|
||||
return llvm::ConstantExpr::getBitCast(
|
||||
moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType);
|
||||
|
@ -144,7 +165,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
|||
llvm::Constant *child = getLLVMConstant(
|
||||
elementType,
|
||||
elementTypeSequential ? splatAttr : splatAttr.getSplatValue(), loc,
|
||||
moduleTranslation);
|
||||
moduleTranslation, false);
|
||||
if (!child)
|
||||
return nullptr;
|
||||
if (llvmType->isVectorTy())
|
||||
|
@ -169,7 +190,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
|||
llvm::Type *innermostType = getInnermostElementType(llvmType);
|
||||
for (auto n : elementsAttr.getValues<Attribute>()) {
|
||||
constants.push_back(
|
||||
getLLVMConstant(innermostType, n, loc, moduleTranslation));
|
||||
getLLVMConstant(innermostType, n, loc, moduleTranslation, false));
|
||||
if (!constants.back())
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -265,6 +265,54 @@ func @constant_wrong_type_string() {
|
|||
|
||||
// -----
|
||||
|
||||
llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
|
||||
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
|
||||
%0 = llvm.mlir.constant([1.0 : f64]) : !llvm.struct<(f64, f64)>
|
||||
llvm.return %0 : !llvm.struct<(f64, f64)>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
|
||||
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
|
||||
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)>
|
||||
llvm.return %0 : !llvm.struct<(f64, f64)>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
|
||||
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
|
||||
%0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
|
||||
llvm.return %0 : !llvm.struct<(f64, f64)>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
|
||||
// expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
|
||||
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64)>
|
||||
llvm.return %0 : !llvm.struct<(f64)>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> {
|
||||
// expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
|
||||
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64, f32)>
|
||||
llvm.return %0 : !llvm.struct<(f64, f32)>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @struct_wrong_element_types() -> !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)> {
|
||||
// expected-error @+1 {{expected struct element types to be floating point type or integer type}}
|
||||
%0 = llvm.mlir.constant([dense<[1.0, 1.0]> : tensor<2xf64>, dense<[1.0, 1.0]> : tensor<2xf64>]) : !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)>
|
||||
llvm.return %0 : !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
|
||||
// expected-error@+1 {{expected LLVM IR Dialect type}}
|
||||
llvm.insertvalue %a, %b[0] : tensor<*xi32>
|
||||
|
|
|
@ -37,3 +37,35 @@ func @unsupported_attribute() {
|
|||
%0 = constant "" : index
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @complex_constant_wrong_array_attribute_length() {
|
||||
// expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}}
|
||||
%0 = constant [1.0 : f32] : complex<f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @complex_constant_wrong_attribute_type() {
|
||||
// expected-error @+1 {{requires attribute's type ('f32') to match op's return type ('complex<f32>')}}
|
||||
%0 = "std.constant" () {value = 1.0 : f32} : () -> complex<f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @complex_constant_wrong_element_types() {
|
||||
// expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}}
|
||||
%0 = constant [1.0 : f32, -1.0 : f32] : complex<f64>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @complex_constant_two_different_element_types() {
|
||||
// expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}}
|
||||
%0 = constant [1.0 : f32, -1.0 : f64] : complex<f64>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -68,3 +68,15 @@ func @switch_i64(%flag : i64, %caseOperand : i32) {
|
|||
^bb3(%bb3arg : i32):
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @constant_complex_f32(
|
||||
func @constant_complex_f32() -> complex<f32> {
|
||||
%result = constant [0.1 : f32, -1.0 : f32] : complex<f32>
|
||||
return %result : complex<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @constant_complex_f64(
|
||||
func @constant_complex_f64() -> complex<f64> {
|
||||
%result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
|
||||
return %result : complex<f64>
|
||||
}
|
||||
|
|
|
@ -35,13 +35,21 @@ llvm.func @invalid_align(%arg0 : f32 {llvm.align = 4}) -> f32 {
|
|||
// -----
|
||||
|
||||
llvm.func @no_nested_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
|
||||
// expected-error @+1 {{struct types are not supported in constants}}
|
||||
// expected-error @+1 {{nested struct types are not supported in constants}}
|
||||
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
|
||||
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
|
||||
// expected-error @+1 {{FloatAttr does not match expected type of the constant}}
|
||||
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
|
||||
llvm.return %0 : !llvm.struct<(f64, f64)>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{unsupported constant value}}
|
||||
llvm.mlir.global internal constant @test([2.5, 7.4]) : !llvm.array<2 x f64>
|
||||
|
||||
|
@ -63,4 +71,6 @@ llvm.func @passthrough_wrong_type() attributes {passthrough = [42]}
|
|||
// -----
|
||||
|
||||
// expected-error @+1 {{expected arrays within 'passthrough' to contain two strings}}
|
||||
llvm.func @passthrough_wrong_type() attributes {passthrough = [[42, 42]]}
|
||||
llvm.func @passthrough_wrong_type() attributes {
|
||||
passthrough = [[ 42, 42 ]]
|
||||
}
|
||||
|
|
|
@ -1016,6 +1016,18 @@ llvm.func @stringconstant() -> !llvm.array<12 x i8> {
|
|||
llvm.return %1 : !llvm.array<12 x i8>
|
||||
}
|
||||
|
||||
llvm.func @complexfpconstant() -> !llvm.struct<(f32, f32)> {
|
||||
%1 = llvm.mlir.constant([-1.000000e+00 : f32, 0.000000e+00 : f32]) : !llvm.struct<(f32, f32)>
|
||||
// CHECK: ret { float, float } { float -1.000000e+00, float 0.000000e+00 }
|
||||
llvm.return %1 : !llvm.struct<(f32, f32)>
|
||||
}
|
||||
|
||||
llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
|
||||
%1 = llvm.mlir.constant([-1 : i32, 0 : i32]) : !llvm.struct<(i32, i32)>
|
||||
// CHECK: ret { i32, i32 } { i32 -1, i32 0 }
|
||||
llvm.return %1 : !llvm.struct<(i32, i32)>
|
||||
}
|
||||
|
||||
llvm.func @noreach() {
|
||||
// CHECK: unreachable
|
||||
llvm.unreachable
|
||||
|
|
Loading…
Reference in New Issue