Add support for complex constants to MLIR core.

BEGIN_PUBLIC
Add support for complex constants to MLIR core.
END_PUBLIC

Differential Revision: https://reviews.llvm.org/D101908
This commit is contained in:
Adrian Kuegel 2021-05-14 13:00:38 +02:00
parent 7ddeffee55
commit 5ef21506b9
12 changed files with 232 additions and 22 deletions

View File

@ -1092,7 +1092,10 @@ def ConstantOp : Std_Op<"constant",
let builders = [
OpBuilder<(ins "Attribute":$value),
[{ build($_builder, $_state, value.getType(), value); }]>];
[{ build($_builder, $_state, value.getType(), value); }]>,
OpBuilder<(ins "Attribute":$value, "Type":$type),
[{ build($_builder, $_state, type, value); }]>,
];
let extraClassDeclaration = [{
Attribute getValue() { return (*this)->getAttr("value"); }

View File

@ -326,11 +326,13 @@ SetVector<Block *> getTopologicallySortedBlocks(Region &region);
/// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
/// This currently supports integer, floating point, splat and dense element
/// attributes and combinations thereof. In case of error, report it to `loc`
/// and return nullptr.
/// attributes and combinations thereof. Also, an array attribute with two
/// elements is supported to represent a complex constant. In case of error,
/// report it to `loc` and return nullptr.
llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
Location loc,
const ModuleTranslation &moduleTranslation);
const ModuleTranslation &moduleTranslation,
bool isTopLevel = true);
/// Creates a call to an LLVM IR intrinsic function with the given arguments.
llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,

View File

@ -1962,7 +1962,30 @@ static LogicalResult verify(LLVM::ConstantOp op) {
}
return success();
}
if (!op.value().isa<IntegerAttr, FloatAttr, ElementsAttr>())
if (auto structType = op.getType().dyn_cast<LLVMStructType>()) {
if (structType.getBody().size() != 2 ||
structType.getBody()[0] != structType.getBody()[1]) {
return op.emitError() << "expected struct type with two elements of the "
"same type, the type of a complex constant";
}
auto arrayAttr = op.value().dyn_cast<ArrayAttr>();
if (!arrayAttr || arrayAttr.size() != 2 ||
arrayAttr[0].getType() != arrayAttr[1].getType()) {
return op.emitOpError() << "expected array attribute with two elements, "
"representing a complex constant";
}
Type elementType = structType.getBody()[0];
if (!elementType
.isa<IntegerType, Float16Type, Float32Type, Float64Type>()) {
return op.emitError()
<< "expected struct element types to be floating point type or "
"integer type";
}
return success();
}
if (!op.value().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
return op.emitOpError()
<< "only supports integer, float, string or elements attributes";
return success();

View File

@ -1200,7 +1200,8 @@ public:
// Create a constant scalar value from the splat constant.
Value scalarConstant = rewriter.create<ConstantOp>(
def->getLoc(), constantAttr.getSplatValue());
def->getLoc(), constantAttr.getSplatValue(),
constantAttr.getType().getElementType());
auto fusedOp = rewriter.create<GenericOp>(
rewriter.getUnknownLoc(), genericOp->getResultTypes(),

View File

@ -1067,8 +1067,8 @@ static void print(OpAsmPrinter &p, ConstantOp &op) {
p << ' ';
p << op.getValue();
// If the value is a symbol reference, print a trailing type.
if (op.getValue().isa<SymbolRefAttr>())
// If the value is a symbol reference or Array, print a trailing type.
if (op.getValue().isa<SymbolRefAttr, ArrayAttr>())
p << " : " << op.getType();
}
@ -1079,9 +1079,10 @@ static ParseResult parseConstantOp(OpAsmParser &parser,
parser.parseAttribute(valueAttr, "value", result.attributes))
return failure();
// If the attribute is a symbol reference, then we expect a trailing type.
// If the attribute is a symbol reference or array, then we expect a trailing
// type.
Type type;
if (!valueAttr.isa<SymbolRefAttr>())
if (!valueAttr.isa<SymbolRefAttr, ArrayAttr>())
type = valueAttr.getType();
else if (parser.parseColonType(type))
return failure();
@ -1119,6 +1120,24 @@ static LogicalResult verify(ConstantOp &op) {
return success();
}
if (auto complexTy = type.dyn_cast<ComplexType>()) {
auto arrayAttr = value.dyn_cast<ArrayAttr>();
if (!complexTy || arrayAttr.size() != 2)
return op.emitOpError(
"requires 'value' to be a complex constant, represented as array of "
"two values");
auto complexEltTy = complexTy.getElementType();
if (complexEltTy != arrayAttr[0].getType() ||
complexEltTy != arrayAttr[1].getType()) {
return op.emitOpError()
<< "requires attribute's element types (" << arrayAttr[0].getType()
<< ", " << arrayAttr[1].getType()
<< ") to match the element type of the op's return type ("
<< complexEltTy << ")";
}
return success();
}
if (type.isa<FloatType>()) {
if (!value.isa<FloatAttr>())
return op.emitOpError("requires 'value' to be a floating point constant");
@ -1193,13 +1212,21 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
if (value.isa<SymbolRefAttr>())
return type.isa<FunctionType>();
// The attribute must have the same type as 'type'.
if (value.getType() != type)
if (!value.getType().isa<NoneType>() && value.getType() != type)
return false;
// If the type is an integer type, it must be signless.
if (IntegerType integerTy = type.dyn_cast<IntegerType>())
if (!integerTy.isSignless())
return false;
// Finally, check that the attribute kind is handled.
if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
auto complexTy = type.dyn_cast<ComplexType>();
if (!complexTy)
return false;
auto complexEltTy = complexTy.getElementType();
return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
arrAttr[1].getType() == complexEltTy;
}
return value.isa<IntegerAttr, FloatAttr, ElementsAttr, UnitAttr>();
}

View File

@ -578,6 +578,25 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
return FloatAttr::get(eltTy, *floatIt);
}
if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
auto complexEltTy = complexTy.getElementType();
ComplexIntElementIterator complexIntIt(owner, index);
if (complexEltTy.isa<IntegerType>()) {
auto value = *complexIntIt;
auto real = IntegerAttr::get(complexEltTy, value.real());
auto imag = IntegerAttr::get(complexEltTy, value.imag());
return ArrayAttr::get(complexTy.getContext(),
ArrayRef<Attribute>{real, imag});
}
ComplexFloatElementIterator complexFloatIt(
complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
auto value = *complexFloatIt;
auto real = FloatAttr::get(complexEltTy, value.real());
auto imag = FloatAttr::get(complexEltTy, value.imag());
return ArrayAttr::get(complexTy.getContext(),
ArrayRef<Attribute>{real, imag});
}
if (owner.isa<DenseStringElementsAttr>()) {
ArrayRef<StringRef> vals = owner.getRawStringData();
return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);

View File

@ -103,16 +103,30 @@ static llvm::Type *getInnermostElementType(llvm::Type *type) {
/// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
/// This currently supports integer, floating point, splat and dense element
/// attributes and combinations thereof. In case of error, report it to `loc`
/// and return nullptr.
/// attributes and combinations thereof. Also, an array attribute with two
/// elements is supported to represent a complex constant. In case of error,
/// report it to `loc` and return nullptr.
llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
llvm::Type *llvmType, Attribute attr, Location loc,
const ModuleTranslation &moduleTranslation) {
const ModuleTranslation &moduleTranslation, bool isTopLevel) {
if (!attr)
return llvm::UndefValue::get(llvmType);
if (llvmType->isStructTy()) {
emitError(loc, "struct types are not supported in constants");
return nullptr;
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
if (!isTopLevel) {
emitError(loc, "nested struct types are not supported in constants");
return nullptr;
}
auto arrayAttr = attr.cast<ArrayAttr>();
llvm::Type *elementType = structType->getElementType(0);
llvm::Constant *real = getLLVMConstant(elementType, arrayAttr[0], loc,
moduleTranslation, false);
if (!real)
return nullptr;
llvm::Constant *imag = getLLVMConstant(elementType, arrayAttr[1], loc,
moduleTranslation, false);
if (!imag)
return nullptr;
return llvm::ConstantStruct::get(structType, {real, imag});
}
// For integer types, we allow a mismatch in sizes as the index type in
// MLIR might have a different size than the index type in the LLVM module.
@ -120,8 +134,15 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
return llvm::ConstantInt::get(
llvmType,
intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
if (auto floatAttr = attr.dyn_cast<FloatAttr>())
if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
if (llvmType !=
llvm::Type::getFloatingPointTy(llvmType->getContext(),
floatAttr.getValue().getSemantics())) {
emitError(loc, "FloatAttr does not match expected type of the constant");
return nullptr;
}
return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
}
if (auto funcAttr = attr.dyn_cast<FlatSymbolRefAttr>())
return llvm::ConstantExpr::getBitCast(
moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType);
@ -144,7 +165,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
llvm::Constant *child = getLLVMConstant(
elementType,
elementTypeSequential ? splatAttr : splatAttr.getSplatValue(), loc,
moduleTranslation);
moduleTranslation, false);
if (!child)
return nullptr;
if (llvmType->isVectorTy())
@ -169,7 +190,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
llvm::Type *innermostType = getInnermostElementType(llvmType);
for (auto n : elementsAttr.getValues<Attribute>()) {
constants.push_back(
getLLVMConstant(innermostType, n, loc, moduleTranslation));
getLLVMConstant(innermostType, n, loc, moduleTranslation, false));
if (!constants.back())
return nullptr;
}

View File

@ -265,6 +265,54 @@ func @constant_wrong_type_string() {
// -----
llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> {
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
%0 = llvm.mlir.constant([1.0 : f64]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
// -----
llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
// -----
llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
// expected-error @+1 {{expected array attribute with two elements, representing a complex constant}}
%0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
// -----
llvm.func @struct_one_element() -> !llvm.struct<(f64)> {
// expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64)>
llvm.return %0 : !llvm.struct<(f64)>
}
// -----
llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> {
// expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}}
%0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64, f32)>
llvm.return %0 : !llvm.struct<(f64, f32)>
}
// -----
llvm.func @struct_wrong_element_types() -> !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)> {
// expected-error @+1 {{expected struct element types to be floating point type or integer type}}
%0 = llvm.mlir.constant([dense<[1.0, 1.0]> : tensor<2xf64>, dense<[1.0, 1.0]> : tensor<2xf64>]) : !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)>
llvm.return %0 : !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)>
}
// -----
func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
// expected-error@+1 {{expected LLVM IR Dialect type}}
llvm.insertvalue %a, %b[0] : tensor<*xi32>

View File

@ -37,3 +37,35 @@ func @unsupported_attribute() {
%0 = constant "" : index
return
}
// -----
func @complex_constant_wrong_array_attribute_length() {
// expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}}
%0 = constant [1.0 : f32] : complex<f32>
return
}
// -----
func @complex_constant_wrong_attribute_type() {
// expected-error @+1 {{requires attribute's type ('f32') to match op's return type ('complex<f32>')}}
%0 = "std.constant" () {value = 1.0 : f32} : () -> complex<f32>
return
}
// -----
func @complex_constant_wrong_element_types() {
// expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}}
%0 = constant [1.0 : f32, -1.0 : f32] : complex<f64>
return
}
// -----
func @complex_constant_two_different_element_types() {
// expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}}
%0 = constant [1.0 : f32, -1.0 : f64] : complex<f64>
return
}

View File

@ -68,3 +68,15 @@ func @switch_i64(%flag : i64, %caseOperand : i32) {
^bb3(%bb3arg : i32):
return
}
// CHECK-LABEL: func @constant_complex_f32(
func @constant_complex_f32() -> complex<f32> {
%result = constant [0.1 : f32, -1.0 : f32] : complex<f32>
return %result : complex<f32>
}
// CHECK-LABEL: func @constant_complex_f64(
func @constant_complex_f64() -> complex<f64> {
%result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
return %result : complex<f64>
}

View File

@ -35,13 +35,21 @@ llvm.func @invalid_align(%arg0 : f32 {llvm.align = 4}) -> f32 {
// -----
llvm.func @no_nested_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
// expected-error @+1 {{struct types are not supported in constants}}
// expected-error @+1 {{nested struct types are not supported in constants}}
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
}
// -----
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
// expected-error @+1 {{FloatAttr does not match expected type of the constant}}
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
}
// -----
// expected-error @+1 {{unsupported constant value}}
llvm.mlir.global internal constant @test([2.5, 7.4]) : !llvm.array<2 x f64>
@ -63,4 +71,6 @@ llvm.func @passthrough_wrong_type() attributes {passthrough = [42]}
// -----
// expected-error @+1 {{expected arrays within 'passthrough' to contain two strings}}
llvm.func @passthrough_wrong_type() attributes {passthrough = [[42, 42]]}
llvm.func @passthrough_wrong_type() attributes {
passthrough = [[ 42, 42 ]]
}

View File

@ -1016,6 +1016,18 @@ llvm.func @stringconstant() -> !llvm.array<12 x i8> {
llvm.return %1 : !llvm.array<12 x i8>
}
llvm.func @complexfpconstant() -> !llvm.struct<(f32, f32)> {
%1 = llvm.mlir.constant([-1.000000e+00 : f32, 0.000000e+00 : f32]) : !llvm.struct<(f32, f32)>
// CHECK: ret { float, float } { float -1.000000e+00, float 0.000000e+00 }
llvm.return %1 : !llvm.struct<(f32, f32)>
}
llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
%1 = llvm.mlir.constant([-1 : i32, 0 : i32]) : !llvm.struct<(i32, i32)>
// CHECK: ret { i32, i32 } { i32 -1, i32 0 }
llvm.return %1 : !llvm.struct<(i32, i32)>
}
llvm.func @noreach() {
// CHECK: unreachable
llvm.unreachable