forked from OSchip/llvm-project
[mlir] Move the complex support of std.constant to a new complex.constant operation
This is part of splitting up the standard dialect. Differential Revision: https://reviews.llvm.org/D118182
This commit is contained in:
parent
b88a4d72d9
commit
480cd4cb85
|
@ -10,13 +10,12 @@
|
|||
#define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/VectorInterfaces.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Complex Dialect
|
||||
|
|
|
@ -19,7 +19,7 @@ def Complex_Dialect : Dialect {
|
|||
arithmetic ops.
|
||||
}];
|
||||
|
||||
let dependentDialects = ["arith::ArithmeticDialect", "StandardOpsDialect"];
|
||||
let dependentDialects = ["arith::ArithmeticDialect"];
|
||||
let hasConstantMaterializer = 1;
|
||||
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#define COMPLEX_OPS
|
||||
|
||||
include "mlir/Dialect/Complex/IR/ComplexBase.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
|
@ -76,6 +77,40 @@ def AddOp : ComplexArithmeticOp<"add"> {
|
|||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConstantOp : Complex_Op<"constant", [
|
||||
ConstantLike, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
|
||||
]> {
|
||||
let summary = "complex number constant operation";
|
||||
let description = [{
|
||||
The `complex.constant` operation creates a constant complex number from an
|
||||
attribute containing the real and imaginary parts.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%a = complex.constant [0.1, -1.0] : complex<f64>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins ArrayAttr:$value);
|
||||
let results = (outs Complex<AnyFloat>:$complex);
|
||||
|
||||
let assemblyFormat = "$value attr-dict `:` type($complex)";
|
||||
let hasFolder = 1;
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Returns true if a constant operation can be built with the given value
|
||||
/// and result type.
|
||||
static bool isBuildableWith(Attribute value, Type type);
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CreateOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -11,6 +11,6 @@ add_mlir_dialect_library(MLIRComplex
|
|||
LINK_LIBS PUBLIC
|
||||
MLIRArithmetic
|
||||
MLIRDialect
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRIR
|
||||
MLIRStandard
|
||||
)
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -25,9 +24,10 @@ Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder,
|
|||
Attribute value,
|
||||
Type type,
|
||||
Location loc) {
|
||||
// TODO complex.constant
|
||||
if (type.isa<ComplexType>())
|
||||
return builder.create<ConstantOp>(loc, type, value);
|
||||
if (complex::ConstantOp::isBuildableWith(value, type)) {
|
||||
return builder.create<complex::ConstantOp>(loc, type,
|
||||
value.cast<ArrayAttr>());
|
||||
}
|
||||
if (arith::ConstantOp::isBuildableWith(value, type))
|
||||
return builder.create<arith::ConstantOp>(loc, type, value);
|
||||
return nullptr;
|
||||
|
|
|
@ -13,11 +13,54 @@ using namespace mlir;
|
|||
using namespace mlir::complex;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
// ConstantOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
|
||||
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.empty() && "constant has no operands");
|
||||
return getValue();
|
||||
}
|
||||
|
||||
void ConstantOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "cst");
|
||||
}
|
||||
|
||||
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
|
||||
if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
|
||||
auto complexTy = type.dyn_cast<ComplexType>();
|
||||
if (!complexTy)
|
||||
return false;
|
||||
auto complexEltTy = complexTy.getElementType();
|
||||
return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
|
||||
arrAttr[1].getType() == complexEltTy;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static LogicalResult verify(ConstantOp op) {
|
||||
ArrayAttr arrayAttr = op.getValue();
|
||||
if (arrayAttr.size() != 2) {
|
||||
return op.emitOpError(
|
||||
"requires 'value' to be a complex constant, represented as array of "
|
||||
"two values");
|
||||
}
|
||||
|
||||
auto complexEltTy = op.getType().getElementType();
|
||||
if (complexEltTy != arrayAttr[0].getType() ||
|
||||
complexEltTy != arrayAttr[1].getType()) {
|
||||
return op.emitOpError()
|
||||
<< "requires attribute's element types (" << arrayAttr[0].getType()
|
||||
<< ", " << arrayAttr[1].getType()
|
||||
<< ") to match the element type of the op's return type ("
|
||||
<< complexEltTy << ")";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CreateOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "binary op takes two operands");
|
||||
|
@ -32,6 +75,10 @@ OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ImOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 1 && "unary op takes 1 operand");
|
||||
ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
|
||||
|
@ -42,6 +89,10 @@ OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 1 && "unary op takes 1 operand");
|
||||
ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
|
||||
|
@ -51,3 +102,10 @@ OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
|
|||
return createOp.getOperand(0);
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
|
||||
|
|
|
@ -669,8 +669,8 @@ static void print(OpAsmPrinter &p, ConstantOp &op) {
|
|||
p << ' ';
|
||||
p << op.getValue();
|
||||
|
||||
// If the value is a symbol reference or Array, print a trailing type.
|
||||
if (op.getValue().isa<SymbolRefAttr, ArrayAttr>())
|
||||
// If the value is a symbol reference, print a trailing type.
|
||||
if (op.getValue().isa<SymbolRefAttr>())
|
||||
p << " : " << op.getType();
|
||||
}
|
||||
|
||||
|
@ -681,10 +681,9 @@ static ParseResult parseConstantOp(OpAsmParser &parser,
|
|||
parser.parseAttribute(valueAttr, "value", result.attributes))
|
||||
return failure();
|
||||
|
||||
// If the attribute is a symbol reference or array, then we expect a trailing
|
||||
// type.
|
||||
// If the attribute is a symbol reference, then we expect a trailing type.
|
||||
Type type;
|
||||
if (!valueAttr.isa<SymbolRefAttr, ArrayAttr>())
|
||||
if (!valueAttr.isa<SymbolRefAttr>())
|
||||
type = valueAttr.getType();
|
||||
else if (parser.parseColonType(type))
|
||||
return failure();
|
||||
|
@ -705,24 +704,6 @@ static LogicalResult verify(ConstantOp &op) {
|
|||
return op.emitOpError() << "requires attribute's type (" << value.getType()
|
||||
<< ") to match op's return type (" << type << ")";
|
||||
|
||||
if (auto complexTy = type.dyn_cast<ComplexType>()) {
|
||||
auto arrayAttr = value.dyn_cast<ArrayAttr>();
|
||||
if (!complexTy || arrayAttr.size() != 2)
|
||||
return op.emitOpError(
|
||||
"requires 'value' to be a complex constant, represented as array of "
|
||||
"two values");
|
||||
auto complexEltTy = complexTy.getElementType();
|
||||
if (complexEltTy != arrayAttr[0].getType() ||
|
||||
complexEltTy != arrayAttr[1].getType()) {
|
||||
return op.emitOpError()
|
||||
<< "requires attribute's element types (" << arrayAttr[0].getType()
|
||||
<< ", " << arrayAttr[1].getType()
|
||||
<< ") to match the element type of the op's return type ("
|
||||
<< complexEltTy << ")";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
if (type.isa<FunctionType>()) {
|
||||
auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
|
||||
if (!fnAttr)
|
||||
|
@ -769,19 +750,8 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
|
|||
// SymbolRefAttr can only be used with a function type.
|
||||
if (value.isa<SymbolRefAttr>())
|
||||
return type.isa<FunctionType>();
|
||||
// The attribute must have the same type as 'type'.
|
||||
if (!value.getType().isa<NoneType>() && value.getType() != type)
|
||||
return false;
|
||||
// Finally, check that the attribute kind is handled.
|
||||
if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
|
||||
auto complexTy = type.dyn_cast<ComplexType>();
|
||||
if (!complexTy)
|
||||
return false;
|
||||
auto complexEltTy = complexTy.getElementType();
|
||||
return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
|
||||
arrAttr[1].getType() == complexEltTy;
|
||||
}
|
||||
return value.isa<UnitAttr>();
|
||||
// Otherwise, this must be a UnitAttr.
|
||||
return value.isa<UnitAttr>() && type.isa<NoneType>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -27,7 +27,7 @@ func @create_of_real_and_imag_different_operand(
|
|||
func @real_of_const() -> f32 {
|
||||
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
|
||||
// CHECK-NEXT: return %[[CST]] : f32
|
||||
%complex = constant [1.0 : f32, 0.0 : f32] : complex<f32>
|
||||
%complex = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
|
||||
%1 = complex.re %complex : complex<f32>
|
||||
return %1 : f32
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ func @real_of_create_op() -> f32 {
|
|||
func @imag_of_const() -> f32 {
|
||||
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-NEXT: return %[[CST]] : f32
|
||||
%complex = constant [1.0 : f32, 0.0 : f32] : complex<f32>
|
||||
%complex = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
|
||||
%1 = complex.im %complex : complex<f32>
|
||||
return %1 : f32
|
||||
}
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
// RUN: mlir-opt -split-input-file %s -verify-diagnostics
|
||||
|
||||
func @complex_constant_wrong_array_attribute_length() {
|
||||
// expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}}
|
||||
%0 = complex.constant [1.0 : f32] : complex<f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @complex_constant_wrong_element_types() {
|
||||
// expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}}
|
||||
%0 = complex.constant [1.0 : f32, -1.0 : f32] : complex<f64>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @complex_constant_two_different_element_types() {
|
||||
// expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}}
|
||||
%0 = complex.constant [1.0 : f32, -1.0 : f64] : complex<f64>
|
||||
return
|
||||
}
|
|
@ -5,6 +5,12 @@
|
|||
// CHECK-LABEL: func @ops(
|
||||
// CHECK-SAME: %[[F:.*]]: f32) {
|
||||
func @ops(%f: f32) {
|
||||
// CHECK: complex.constant [1.{{.*}}, -1.{{.*}}] : complex<f64>
|
||||
%cst_f64 = complex.constant [0.1, -1.0] : complex<f64>
|
||||
|
||||
// CHECK: complex.constant [1.{{.*}} : f32, -1.{{.*}} : f32] : complex<f32>
|
||||
%cst_f32 = complex.constant [0.1 : f32, -1.0 : f32] : complex<f32>
|
||||
|
||||
// CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex<f32>
|
||||
%complex = complex.create %f, %f : complex<f32>
|
||||
|
||||
|
@ -51,4 +57,3 @@ func @ops(%f: f32) {
|
|||
%diff = complex.sub %complex, %complex : complex<f32>
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -8,30 +8,6 @@ func @unsupported_attribute() {
|
|||
|
||||
// -----
|
||||
|
||||
func @complex_constant_wrong_array_attribute_length() {
|
||||
// expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}}
|
||||
%0 = constant [1.0 : f32] : complex<f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @complex_constant_wrong_element_types() {
|
||||
// expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}}
|
||||
%0 = constant [1.0 : f32, -1.0 : f32] : complex<f64>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @complex_constant_two_different_element_types() {
|
||||
// expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}}
|
||||
%0 = constant [1.0 : f32, -1.0 : f64] : complex<f64>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @return_i32_f32() -> (i32, f32) {
|
||||
%0 = arith.constant 1 : i32
|
||||
%1 = arith.constant 1. : f32
|
||||
|
|
|
@ -51,18 +51,6 @@ func @switch_i64(%flag : i64, %caseOperand : i32) {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @constant_complex_f32(
|
||||
func @constant_complex_f32() -> complex<f32> {
|
||||
%result = constant [0.1 : f32, -1.0 : f32] : complex<f32>
|
||||
return %result : complex<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @constant_complex_f64(
|
||||
func @constant_complex_f64() -> complex<f64> {
|
||||
%result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
|
||||
return %result : complex<f64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @vector_splat_0d(
|
||||
func @vector_splat_0d(%a: f32) -> vector<f32> {
|
||||
// CHECK: splat %{{.*}} : vector<f32>
|
||||
|
|
Loading…
Reference in New Issue