diff --git a/mlir/include/mlir/Dialect/Complex/IR/Complex.h b/mlir/include/mlir/Dialect/Complex/IR/Complex.h index 6f3026a5affb..2a8a8e7a1833 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/Complex.h +++ b/mlir/include/mlir/Dialect/Complex/IR/Complex.h @@ -10,13 +10,12 @@ #define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Interfaces/VectorInterfaces.h" //===----------------------------------------------------------------------===// // Complex Dialect diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td index 4382183254ac..6cb7d9d92b58 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td @@ -19,7 +19,7 @@ def Complex_Dialect : Dialect { arithmetic ops. }]; - let dependentDialects = ["arith::ArithmeticDialect", "StandardOpsDialect"]; + let dependentDialects = ["arith::ArithmeticDialect"]; let hasConstantMaterializer = 1; let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; } diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index 02b44ff16f56..a79d7ac8a215 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -10,6 +10,7 @@ #define COMPLEX_OPS include "mlir/Dialect/Complex/IR/ComplexBase.td" +include "mlir/IR/OpAsmInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -76,6 +77,40 @@ def AddOp : ComplexArithmeticOp<"add"> { }]; } +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +def ConstantOp : Complex_Op<"constant", [ + ConstantLike, NoSideEffect, + DeclareOpInterfaceMethods + ]> { + let summary = "complex number constant operation"; + let description = [{ + The `complex.constant` operation creates a constant complex number from an + attribute containing the real and imaginary parts. + + Example: + + ```mlir + %a = complex.constant [0.1, -1.0] : complex + ``` + }]; + + let arguments = (ins ArrayAttr:$value); + let results = (outs Complex:$complex); + + let assemblyFormat = "$value attr-dict `:` type($complex)"; + let hasFolder = 1; + let verifier = [{ return ::verify(*this); }]; + + let extraClassDeclaration = [{ + /// Returns true if a constant operation can be built with the given value + /// and result type. + static bool isBuildableWith(Attribute value, Type type); + }]; +} + //===----------------------------------------------------------------------===// // CreateOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Complex/IR/CMakeLists.txt b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt index fdb7e748a01b..da419d9b2599 100644 --- a/mlir/lib/Dialect/Complex/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt @@ -11,6 +11,6 @@ add_mlir_dialect_library(MLIRComplex LINK_LIBS PUBLIC MLIRArithmetic MLIRDialect + MLIRInferTypeOpInterface MLIRIR - MLIRStandard ) diff --git a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp index a5aa9799f5c1..f189c2b2666d 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp @@ -8,7 +8,6 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" using namespace mlir; @@ -25,9 +24,10 @@ Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - // TODO complex.constant - if (type.isa()) - return builder.create(loc, type, value); + if (complex::ConstantOp::isBuildableWith(value, type)) { + return builder.create(loc, type, + value.cast()); + } if (arith::ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, value); return nullptr; diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp index 58412d37605b..36745c5264b8 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -13,11 +13,54 @@ using namespace mlir; using namespace mlir::complex; //===----------------------------------------------------------------------===// -// TableGen'd op method definitions +// ConstantOp //===----------------------------------------------------------------------===// -#define GET_OP_CLASSES -#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc" +OpFoldResult ConstantOp::fold(ArrayRef operands) { + assert(operands.empty() && "constant has no operands"); + return getValue(); +} + +void ConstantOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "cst"); +} + +bool ConstantOp::isBuildableWith(Attribute value, Type type) { + if (auto arrAttr = value.dyn_cast()) { + auto complexTy = type.dyn_cast(); + if (!complexTy) + return false; + auto complexEltTy = complexTy.getElementType(); + return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy && + arrAttr[1].getType() == complexEltTy; + } + return false; +} + +static LogicalResult verify(ConstantOp op) { + ArrayAttr arrayAttr = op.getValue(); + if (arrayAttr.size() != 2) { + return op.emitOpError( + "requires 'value' to be a complex constant, represented as array of " + "two values"); + } + + auto complexEltTy = op.getType().getElementType(); + if (complexEltTy != arrayAttr[0].getType() || + complexEltTy != arrayAttr[1].getType()) { + return op.emitOpError() + << "requires attribute's element types (" << arrayAttr[0].getType() + << ", " << arrayAttr[1].getType() + << ") to match the element type of the op's return type (" + << complexEltTy << ")"; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// CreateOp +//===----------------------------------------------------------------------===// OpFoldResult CreateOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "binary op takes two operands"); @@ -32,6 +75,10 @@ OpFoldResult CreateOp::fold(ArrayRef operands) { return {}; } +//===----------------------------------------------------------------------===// +// ImOp +//===----------------------------------------------------------------------===// + OpFoldResult ImOp::fold(ArrayRef operands) { assert(operands.size() == 1 && "unary op takes 1 operand"); ArrayAttr arrayAttr = operands[0].dyn_cast_or_null(); @@ -42,6 +89,10 @@ OpFoldResult ImOp::fold(ArrayRef operands) { return {}; } +//===----------------------------------------------------------------------===// +// ReOp +//===----------------------------------------------------------------------===// + OpFoldResult ReOp::fold(ArrayRef operands) { assert(operands.size() == 1 && "unary op takes 1 operand"); ArrayAttr arrayAttr = operands[0].dyn_cast_or_null(); @@ -51,3 +102,10 @@ OpFoldResult ReOp::fold(ArrayRef operands) { return createOp.getOperand(0); return {}; } + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc" diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 111087939525..7487aea0902d 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -669,8 +669,8 @@ static void print(OpAsmPrinter &p, ConstantOp &op) { p << ' '; p << op.getValue(); - // If the value is a symbol reference or Array, print a trailing type. - if (op.getValue().isa()) + // If the value is a symbol reference, print a trailing type. + if (op.getValue().isa()) p << " : " << op.getType(); } @@ -681,10 +681,9 @@ static ParseResult parseConstantOp(OpAsmParser &parser, parser.parseAttribute(valueAttr, "value", result.attributes)) return failure(); - // If the attribute is a symbol reference or array, then we expect a trailing - // type. + // If the attribute is a symbol reference, then we expect a trailing type. Type type; - if (!valueAttr.isa()) + if (!valueAttr.isa()) type = valueAttr.getType(); else if (parser.parseColonType(type)) return failure(); @@ -705,24 +704,6 @@ static LogicalResult verify(ConstantOp &op) { return op.emitOpError() << "requires attribute's type (" << value.getType() << ") to match op's return type (" << type << ")"; - if (auto complexTy = type.dyn_cast()) { - auto arrayAttr = value.dyn_cast(); - if (!complexTy || arrayAttr.size() != 2) - return op.emitOpError( - "requires 'value' to be a complex constant, represented as array of " - "two values"); - auto complexEltTy = complexTy.getElementType(); - if (complexEltTy != arrayAttr[0].getType() || - complexEltTy != arrayAttr[1].getType()) { - return op.emitOpError() - << "requires attribute's element types (" << arrayAttr[0].getType() - << ", " << arrayAttr[1].getType() - << ") to match the element type of the op's return type (" - << complexEltTy << ")"; - } - return success(); - } - if (type.isa()) { auto fnAttr = value.dyn_cast(); if (!fnAttr) @@ -769,19 +750,8 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) { // SymbolRefAttr can only be used with a function type. if (value.isa()) return type.isa(); - // The attribute must have the same type as 'type'. - if (!value.getType().isa() && value.getType() != type) - return false; - // Finally, check that the attribute kind is handled. - if (auto arrAttr = value.dyn_cast()) { - auto complexTy = type.dyn_cast(); - if (!complexTy) - return false; - auto complexEltTy = complexTy.getElementType(); - return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy && - arrAttr[1].getType() == complexEltTy; - } - return value.isa(); + // Otherwise, this must be a UnitAttr. + return value.isa() && type.isa(); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir index 038de9908cf2..c68d87e8c077 100644 --- a/mlir/test/Dialect/Complex/canonicalize.mlir +++ b/mlir/test/Dialect/Complex/canonicalize.mlir @@ -27,7 +27,7 @@ func @create_of_real_and_imag_different_operand( func @real_of_const() -> f32 { // CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32 // CHECK-NEXT: return %[[CST]] : f32 - %complex = constant [1.0 : f32, 0.0 : f32] : complex + %complex = complex.constant [1.0 : f32, 0.0 : f32] : complex %1 = complex.re %complex : complex return %1 : f32 } @@ -47,7 +47,7 @@ func @real_of_create_op() -> f32 { func @imag_of_const() -> f32 { // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-NEXT: return %[[CST]] : f32 - %complex = constant [1.0 : f32, 0.0 : f32] : complex + %complex = complex.constant [1.0 : f32, 0.0 : f32] : complex %1 = complex.im %complex : complex return %1 : f32 } diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir new file mode 100644 index 000000000000..ec046effacf8 --- /dev/null +++ b/mlir/test/Dialect/Complex/invalid.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt -split-input-file %s -verify-diagnostics + +func @complex_constant_wrong_array_attribute_length() { + // expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}} + %0 = complex.constant [1.0 : f32] : complex + return +} + +// ----- + +func @complex_constant_wrong_element_types() { + // expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}} + %0 = complex.constant [1.0 : f32, -1.0 : f32] : complex + return +} + +// ----- + +func @complex_constant_two_different_element_types() { + // expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}} + %0 = complex.constant [1.0 : f32, -1.0 : f64] : complex + return +} diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir index 3fc0e9299c0f..75bb082efb2a 100644 --- a/mlir/test/Dialect/Complex/ops.mlir +++ b/mlir/test/Dialect/Complex/ops.mlir @@ -5,6 +5,12 @@ // CHECK-LABEL: func @ops( // CHECK-SAME: %[[F:.*]]: f32) { func @ops(%f: f32) { + // CHECK: complex.constant [1.{{.*}}, -1.{{.*}}] : complex + %cst_f64 = complex.constant [0.1, -1.0] : complex + + // CHECK: complex.constant [1.{{.*}} : f32, -1.{{.*}} : f32] : complex + %cst_f32 = complex.constant [0.1 : f32, -1.0 : f32] : complex + // CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex %complex = complex.create %f, %f : complex @@ -51,4 +57,3 @@ func @ops(%f: f32) { %diff = complex.sub %complex, %complex : complex return } - diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir index defd1aed97b1..836158dd2160 100644 --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -8,30 +8,6 @@ func @unsupported_attribute() { // ----- -func @complex_constant_wrong_array_attribute_length() { - // expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}} - %0 = constant [1.0 : f32] : complex - return -} - -// ----- - -func @complex_constant_wrong_element_types() { - // expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}} - %0 = constant [1.0 : f32, -1.0 : f32] : complex - return -} - -// ----- - -func @complex_constant_two_different_element_types() { - // expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}} - %0 = constant [1.0 : f32, -1.0 : f64] : complex - return -} - -// ----- - func @return_i32_f32() -> (i32, f32) { %0 = arith.constant 1 : i32 %1 = arith.constant 1. : f32 diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir index 64322f066b35..1c3570eedb7e 100644 --- a/mlir/test/Dialect/Standard/ops.mlir +++ b/mlir/test/Dialect/Standard/ops.mlir @@ -51,18 +51,6 @@ func @switch_i64(%flag : i64, %caseOperand : i32) { return } -// CHECK-LABEL: func @constant_complex_f32( -func @constant_complex_f32() -> complex { - %result = constant [0.1 : f32, -1.0 : f32] : complex - return %result : complex -} - -// CHECK-LABEL: func @constant_complex_f64( -func @constant_complex_f64() -> complex { - %result = constant [0.1 : f64, -1.0 : f64] : complex - return %result : complex -} - // CHECK-LABEL: func @vector_splat_0d( func @vector_splat_0d(%a: f32) -> vector { // CHECK: splat %{{.*}} : vector