[mlir] Move the complex support of std.constant to a new complex.constant operation

This is part of splitting up the standard dialect.

Differential Revision: https://reviews.llvm.org/D118182
This commit is contained in:
River Riddle 2022-01-25 13:20:01 -08:00
parent b88a4d72d9
commit 480cd4cb85
12 changed files with 140 additions and 86 deletions

View File

@ -10,13 +10,12 @@
#define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
//===----------------------------------------------------------------------===//
// Complex Dialect

View File

@ -19,7 +19,7 @@ def Complex_Dialect : Dialect {
arithmetic ops.
}];
let dependentDialects = ["arith::ArithmeticDialect", "StandardOpsDialect"];
let dependentDialects = ["arith::ArithmeticDialect"];
let hasConstantMaterializer = 1;
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
}

View File

@ -10,6 +10,7 @@
#define COMPLEX_OPS
include "mlir/Dialect/Complex/IR/ComplexBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@ -76,6 +77,40 @@ def AddOp : ComplexArithmeticOp<"add"> {
}];
}
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
def ConstantOp : Complex_Op<"constant", [
ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "complex number constant operation";
let description = [{
The `complex.constant` operation creates a constant complex number from an
attribute containing the real and imaginary parts.
Example:
```mlir
%a = complex.constant [0.1, -1.0] : complex<f64>
```
}];
let arguments = (ins ArrayAttr:$value);
let results = (outs Complex<AnyFloat>:$complex);
let assemblyFormat = "$value attr-dict `:` type($complex)";
let hasFolder = 1;
let verifier = [{ return ::verify(*this); }];
let extraClassDeclaration = [{
/// Returns true if a constant operation can be built with the given value
/// and result type.
static bool isBuildableWith(Attribute value, Type type);
}];
}
//===----------------------------------------------------------------------===//
// CreateOp
//===----------------------------------------------------------------------===//

View File

@ -11,6 +11,6 @@ add_mlir_dialect_library(MLIRComplex
LINK_LIBS PUBLIC
MLIRArithmetic
MLIRDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRStandard
)

View File

@ -8,7 +8,6 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
using namespace mlir;
@ -25,9 +24,10 @@ Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder,
Attribute value,
Type type,
Location loc) {
// TODO complex.constant
if (type.isa<ComplexType>())
return builder.create<ConstantOp>(loc, type, value);
if (complex::ConstantOp::isBuildableWith(value, type)) {
return builder.create<complex::ConstantOp>(loc, type,
value.cast<ArrayAttr>());
}
if (arith::ConstantOp::isBuildableWith(value, type))
return builder.create<arith::ConstantOp>(loc, type, value);
return nullptr;

View File

@ -13,11 +13,54 @@ using namespace mlir;
using namespace mlir::complex;
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
// ConstantOp
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
return getValue();
}
void ConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "cst");
}
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
auto complexTy = type.dyn_cast<ComplexType>();
if (!complexTy)
return false;
auto complexEltTy = complexTy.getElementType();
return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
arrAttr[1].getType() == complexEltTy;
}
return false;
}
static LogicalResult verify(ConstantOp op) {
ArrayAttr arrayAttr = op.getValue();
if (arrayAttr.size() != 2) {
return op.emitOpError(
"requires 'value' to be a complex constant, represented as array of "
"two values");
}
auto complexEltTy = op.getType().getElementType();
if (complexEltTy != arrayAttr[0].getType() ||
complexEltTy != arrayAttr[1].getType()) {
return op.emitOpError()
<< "requires attribute's element types (" << arrayAttr[0].getType()
<< ", " << arrayAttr[1].getType()
<< ") to match the element type of the op's return type ("
<< complexEltTy << ")";
}
return success();
}
//===----------------------------------------------------------------------===//
// CreateOp
//===----------------------------------------------------------------------===//
OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "binary op takes two operands");
@ -32,6 +75,10 @@ OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
return {};
}
//===----------------------------------------------------------------------===//
// ImOp
//===----------------------------------------------------------------------===//
OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1 && "unary op takes 1 operand");
ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
@ -42,6 +89,10 @@ OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
return {};
}
//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//
OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1 && "unary op takes 1 operand");
ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
@ -51,3 +102,10 @@ OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
return createOp.getOperand(0);
return {};
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"

View File

@ -669,8 +669,8 @@ static void print(OpAsmPrinter &p, ConstantOp &op) {
p << ' ';
p << op.getValue();
// If the value is a symbol reference or Array, print a trailing type.
if (op.getValue().isa<SymbolRefAttr, ArrayAttr>())
// If the value is a symbol reference, print a trailing type.
if (op.getValue().isa<SymbolRefAttr>())
p << " : " << op.getType();
}
@ -681,10 +681,9 @@ static ParseResult parseConstantOp(OpAsmParser &parser,
parser.parseAttribute(valueAttr, "value", result.attributes))
return failure();
// If the attribute is a symbol reference or array, then we expect a trailing
// type.
// If the attribute is a symbol reference, then we expect a trailing type.
Type type;
if (!valueAttr.isa<SymbolRefAttr, ArrayAttr>())
if (!valueAttr.isa<SymbolRefAttr>())
type = valueAttr.getType();
else if (parser.parseColonType(type))
return failure();
@ -705,24 +704,6 @@ static LogicalResult verify(ConstantOp &op) {
return op.emitOpError() << "requires attribute's type (" << value.getType()
<< ") to match op's return type (" << type << ")";
if (auto complexTy = type.dyn_cast<ComplexType>()) {
auto arrayAttr = value.dyn_cast<ArrayAttr>();
if (!complexTy || arrayAttr.size() != 2)
return op.emitOpError(
"requires 'value' to be a complex constant, represented as array of "
"two values");
auto complexEltTy = complexTy.getElementType();
if (complexEltTy != arrayAttr[0].getType() ||
complexEltTy != arrayAttr[1].getType()) {
return op.emitOpError()
<< "requires attribute's element types (" << arrayAttr[0].getType()
<< ", " << arrayAttr[1].getType()
<< ") to match the element type of the op's return type ("
<< complexEltTy << ")";
}
return success();
}
if (type.isa<FunctionType>()) {
auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
if (!fnAttr)
@ -769,19 +750,8 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
// SymbolRefAttr can only be used with a function type.
if (value.isa<SymbolRefAttr>())
return type.isa<FunctionType>();
// The attribute must have the same type as 'type'.
if (!value.getType().isa<NoneType>() && value.getType() != type)
return false;
// Finally, check that the attribute kind is handled.
if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
auto complexTy = type.dyn_cast<ComplexType>();
if (!complexTy)
return false;
auto complexEltTy = complexTy.getElementType();
return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
arrAttr[1].getType() == complexEltTy;
}
return value.isa<UnitAttr>();
// Otherwise, this must be a UnitAttr.
return value.isa<UnitAttr>() && type.isa<NoneType>();
}
//===----------------------------------------------------------------------===//

View File

@ -27,7 +27,7 @@ func @create_of_real_and_imag_different_operand(
func @real_of_const() -> f32 {
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-NEXT: return %[[CST]] : f32
%complex = constant [1.0 : f32, 0.0 : f32] : complex<f32>
%complex = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
%1 = complex.re %complex : complex<f32>
return %1 : f32
}
@ -47,7 +47,7 @@ func @real_of_create_op() -> f32 {
func @imag_of_const() -> f32 {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: return %[[CST]] : f32
%complex = constant [1.0 : f32, 0.0 : f32] : complex<f32>
%complex = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
%1 = complex.im %complex : complex<f32>
return %1 : f32
}

View File

@ -0,0 +1,23 @@
// RUN: mlir-opt -split-input-file %s -verify-diagnostics
func @complex_constant_wrong_array_attribute_length() {
// expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}}
%0 = complex.constant [1.0 : f32] : complex<f32>
return
}
// -----
func @complex_constant_wrong_element_types() {
// expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}}
%0 = complex.constant [1.0 : f32, -1.0 : f32] : complex<f64>
return
}
// -----
func @complex_constant_two_different_element_types() {
// expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}}
%0 = complex.constant [1.0 : f32, -1.0 : f64] : complex<f64>
return
}

View File

@ -5,6 +5,12 @@
// CHECK-LABEL: func @ops(
// CHECK-SAME: %[[F:.*]]: f32) {
func @ops(%f: f32) {
// CHECK: complex.constant [1.{{.*}}, -1.{{.*}}] : complex<f64>
%cst_f64 = complex.constant [0.1, -1.0] : complex<f64>
// CHECK: complex.constant [1.{{.*}} : f32, -1.{{.*}} : f32] : complex<f32>
%cst_f32 = complex.constant [0.1 : f32, -1.0 : f32] : complex<f32>
// CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex<f32>
%complex = complex.create %f, %f : complex<f32>
@ -51,4 +57,3 @@ func @ops(%f: f32) {
%diff = complex.sub %complex, %complex : complex<f32>
return
}

View File

@ -8,30 +8,6 @@ func @unsupported_attribute() {
// -----
func @complex_constant_wrong_array_attribute_length() {
// expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}}
%0 = constant [1.0 : f32] : complex<f32>
return
}
// -----
func @complex_constant_wrong_element_types() {
// expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}}
%0 = constant [1.0 : f32, -1.0 : f32] : complex<f64>
return
}
// -----
func @complex_constant_two_different_element_types() {
// expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}}
%0 = constant [1.0 : f32, -1.0 : f64] : complex<f64>
return
}
// -----
func @return_i32_f32() -> (i32, f32) {
%0 = arith.constant 1 : i32
%1 = arith.constant 1. : f32

View File

@ -51,18 +51,6 @@ func @switch_i64(%flag : i64, %caseOperand : i32) {
return
}
// CHECK-LABEL: func @constant_complex_f32(
func @constant_complex_f32() -> complex<f32> {
%result = constant [0.1 : f32, -1.0 : f32] : complex<f32>
return %result : complex<f32>
}
// CHECK-LABEL: func @constant_complex_f64(
func @constant_complex_f64() -> complex<f64> {
%result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
return %result : complex<f64>
}
// CHECK-LABEL: func @vector_splat_0d(
func @vector_splat_0d(%a: f32) -> vector<f32> {
// CHECK: splat %{{.*}} : vector<f32>