forked from OSchip/llvm-project
[mlir] Add Complex dialect.
Differential Revision: https://reviews.llvm.org/D94764
This commit is contained in:
parent
cc90d41945
commit
d0cb0d30a4
|
@ -0,0 +1,29 @@
|
||||||
|
//===- ComplexToLLVM.h - Utils to convert from the complex dialect --------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
#ifndef MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_
|
||||||
|
#define MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_
|
||||||
|
|
||||||
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class MLIRContext;
|
||||||
|
class ModuleOp;
|
||||||
|
template <typename T>
|
||||||
|
class OperationPass;
|
||||||
|
|
||||||
|
/// Populate the given list with patterns that convert from Complex to LLVM.
|
||||||
|
void populateComplexToLLVMConversionPatterns(
|
||||||
|
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
|
||||||
|
|
||||||
|
/// Create a pass to convert Complex operations to the LLVMIR dialect.
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createConvertComplexToLLVMPass();
|
||||||
|
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_
|
|
@ -11,6 +11,7 @@
|
||||||
|
|
||||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||||
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
|
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
|
||||||
|
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
|
||||||
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
|
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
|
||||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||||
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
|
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
|
||||||
|
|
|
@ -88,6 +88,16 @@ def ConvertAsyncToLLVM : Pass<"convert-async-to-llvm", "ModuleOp"> {
|
||||||
let dependentDialects = ["LLVM::LLVMDialect"];
|
let dependentDialects = ["LLVM::LLVMDialect"];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ComplexToLLVM
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def ConvertComplexToLLVM : Pass<"convert-complex-to-llvm", "ModuleOp"> {
|
||||||
|
let summary = "Convert Complex dialect to LLVM dialect";
|
||||||
|
let constructor = "mlir::createConvertComplexToLLVMPass()";
|
||||||
|
let dependentDialects = ["LLVM::LLVMDialect"];
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// GPUCommon
|
// GPUCommon
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -3,6 +3,7 @@ add_subdirectory(Async)
|
||||||
add_subdirectory(ArmNeon)
|
add_subdirectory(ArmNeon)
|
||||||
add_subdirectory(ArmSVE)
|
add_subdirectory(ArmSVE)
|
||||||
add_subdirectory(AVX512)
|
add_subdirectory(AVX512)
|
||||||
|
add_subdirectory(Complex)
|
||||||
add_subdirectory(GPU)
|
add_subdirectory(GPU)
|
||||||
add_subdirectory(Linalg)
|
add_subdirectory(Linalg)
|
||||||
add_subdirectory(LLVMIR)
|
add_subdirectory(LLVMIR)
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(IR)
|
|
@ -0,0 +1,2 @@
|
||||||
|
add_mlir_dialect(ComplexOps complex)
|
||||||
|
add_mlir_doc(ComplexOps -gen-dialect-doc ComplexOps Dialects/)
|
|
@ -0,0 +1,32 @@
|
||||||
|
//===- Complex.h - Complex dialect --------------------------------*- C++-*-==//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_
|
||||||
|
#define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
#include "mlir/Interfaces/VectorInterfaces.h"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Complex Dialect
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Complex/IR/ComplexOpsDialect.h.inc"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Complex Dialect Operations
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "mlir/Dialect/Complex/IR/ComplexOps.h.inc"
|
||||||
|
|
||||||
|
#endif // MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_
|
|
@ -0,0 +1,23 @@
|
||||||
|
//===- ComplexBase.td - Base definitions for complex dialect -*- tablegen -*-=//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef COMPLEX_BASE
|
||||||
|
#define COMPLEX_BASE
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
def Complex_Dialect : Dialect {
|
||||||
|
let name = "complex";
|
||||||
|
let cppNamespace = "::mlir::complex";
|
||||||
|
let description = [{
|
||||||
|
The complex dialect is intended to hold complex numbers creation and
|
||||||
|
arithmetic ops.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // COMPLEX_BASE
|
|
@ -0,0 +1,153 @@
|
||||||
|
//===- ComplexOps.td - Complex op definitions ----------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef COMPLEX_OPS
|
||||||
|
#define COMPLEX_OPS
|
||||||
|
|
||||||
|
include "mlir/Dialect/Complex/IR/ComplexBase.td"
|
||||||
|
include "mlir/Interfaces/VectorInterfaces.td"
|
||||||
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
|
||||||
|
class Complex_Op<string mnemonic, list<OpTrait> traits = []>
|
||||||
|
: Op<Complex_Dialect, mnemonic, traits>;
|
||||||
|
|
||||||
|
// Base class for standard arithmetic operations on complex numbers with a
|
||||||
|
// floating-point element type. These operations take two operands and return
|
||||||
|
// one result, all of which must be complex numbers of the same type.
|
||||||
|
class ComplexArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||||
|
Complex_Op<mnemonic,
|
||||||
|
!listconcat(traits, [NoSideEffect,
|
||||||
|
SameOperandsAndResultType,
|
||||||
|
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
|
||||||
|
ElementwiseMappable])> {
|
||||||
|
let arguments = (ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs);
|
||||||
|
let results = (outs Complex<AnyFloat>:$result);
|
||||||
|
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
|
||||||
|
let verifier = ?;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AddOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def AddOp : ComplexArithmeticOp<"add"> {
|
||||||
|
let summary = "complex addition";
|
||||||
|
let description = [{
|
||||||
|
The `add` operation takes two complex numbers and returns their sum.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
%a = add %b, %c : complex<f32>
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// CreateOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def CreateOp : Complex_Op<"create",
|
||||||
|
[NoSideEffect,
|
||||||
|
AllTypesMatch<["real", "imaginary"]>,
|
||||||
|
TypesMatchWith<"complex element type matches real operand type",
|
||||||
|
"complex", "real",
|
||||||
|
"$_self.cast<ComplexType>().getElementType()">,
|
||||||
|
TypesMatchWith<"complex element type matches imaginary operand type",
|
||||||
|
"complex", "imaginary",
|
||||||
|
"$_self.cast<ComplexType>().getElementType()">]> {
|
||||||
|
|
||||||
|
let summary = "complex number creation operation";
|
||||||
|
let description = [{
|
||||||
|
The `complex.complex` operation creates a complex number from two
|
||||||
|
floating-point operands, the real and the imaginary part.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
%a = create_complex %b, %c : complex<f32>
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins AnyFloat:$real, AnyFloat:$imaginary);
|
||||||
|
let results = (outs Complex<AnyFloat>:$complex);
|
||||||
|
|
||||||
|
let assemblyFormat = "$real `,` $imaginary attr-dict `:` type($complex)";
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ImOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def ImOp : Complex_Op<"im",
|
||||||
|
[NoSideEffect,
|
||||||
|
TypesMatchWith<"complex element type matches result type",
|
||||||
|
"complex", "imaginary",
|
||||||
|
"$_self.cast<ComplexType>().getElementType()">]> {
|
||||||
|
let summary = "extracts the imaginary part of a complex number";
|
||||||
|
let description = [{
|
||||||
|
The `im` op takes a single complex number and extracts the imaginary part.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
%a = im %b : complex<f32>
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins Complex<AnyFloat>:$complex);
|
||||||
|
let results = (outs AnyFloat:$imaginary);
|
||||||
|
|
||||||
|
let assemblyFormat = "$complex attr-dict `:` type($complex)";
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ReOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def ReOp : Complex_Op<"re",
|
||||||
|
[NoSideEffect,
|
||||||
|
TypesMatchWith<"complex element type matches result type",
|
||||||
|
"complex", "real",
|
||||||
|
"$_self.cast<ComplexType>().getElementType()">]> {
|
||||||
|
let summary = "extracts the real part of a complex number";
|
||||||
|
let description = [{
|
||||||
|
The `re` op takes a single complex number and extracts the real part.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
%a = re %b : complex<f32>
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins Complex<AnyFloat>:$complex);
|
||||||
|
let results = (outs AnyFloat:$real);
|
||||||
|
|
||||||
|
let assemblyFormat = "$complex attr-dict `:` type($complex)";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// SubOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def SubOp : ComplexArithmeticOp<"sub"> {
|
||||||
|
let summary = "complex subtraction";
|
||||||
|
let description = [{
|
||||||
|
The `sub` operation takes two complex numbers and returns their difference.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
%a = sub %b, %c : complex<f32>
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // COMPLEX_OPS
|
|
@ -19,6 +19,7 @@
|
||||||
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
|
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
|
||||||
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
|
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
|
||||||
#include "mlir/Dialect/Async/IR/Async.h"
|
#include "mlir/Dialect/Async/IR/Async.h"
|
||||||
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
|
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
|
#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
|
||||||
|
@ -52,6 +53,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
|
||||||
arm_neon::ArmNeonDialect,
|
arm_neon::ArmNeonDialect,
|
||||||
async::AsyncDialect,
|
async::AsyncDialect,
|
||||||
avx512::AVX512Dialect,
|
avx512::AVX512Dialect,
|
||||||
|
complex::ComplexDialect,
|
||||||
gpu::GPUDialect,
|
gpu::GPUDialect,
|
||||||
LLVM::LLVMAVX512Dialect,
|
LLVM::LLVMAVX512Dialect,
|
||||||
LLVM::LLVMDialect,
|
LLVM::LLVMDialect,
|
||||||
|
|
|
@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
|
||||||
add_subdirectory(ArmNeonToLLVM)
|
add_subdirectory(ArmNeonToLLVM)
|
||||||
add_subdirectory(AsyncToLLVM)
|
add_subdirectory(AsyncToLLVM)
|
||||||
add_subdirectory(AVX512ToLLVM)
|
add_subdirectory(AVX512ToLLVM)
|
||||||
|
add_subdirectory(ComplexToLLVM)
|
||||||
add_subdirectory(GPUCommon)
|
add_subdirectory(GPUCommon)
|
||||||
add_subdirectory(GPUToNVVM)
|
add_subdirectory(GPUToNVVM)
|
||||||
add_subdirectory(GPUToROCDL)
|
add_subdirectory(GPUToROCDL)
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
add_mlir_conversion_library(MLIRComplexToLLVM
|
||||||
|
ComplexToLLVM.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToLLVM
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
MLIRConversionPassIncGen
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRComplex
|
||||||
|
MLIRLLVMIR
|
||||||
|
MLIRStandardOpsTransforms
|
||||||
|
MLIRStandardToLLVM
|
||||||
|
MLIRTransforms
|
||||||
|
)
|
|
@ -0,0 +1,193 @@
|
||||||
|
//===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
|
||||||
|
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::LLVM;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
|
||||||
|
using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(complex::CreateOp complexOp, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
complex::CreateOp::Adaptor transformed(operands);
|
||||||
|
|
||||||
|
// Pack real and imaginary part in a complex number struct.
|
||||||
|
auto loc = complexOp.getLoc();
|
||||||
|
auto structType = typeConverter->convertType(complexOp.getType());
|
||||||
|
auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
|
||||||
|
complexStruct.setReal(rewriter, loc, transformed.real());
|
||||||
|
complexStruct.setImaginary(rewriter, loc, transformed.imaginary());
|
||||||
|
|
||||||
|
rewriter.replaceOp(complexOp, {complexStruct});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
|
||||||
|
using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(complex::ReOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
complex::ReOp::Adaptor transformed(operands);
|
||||||
|
|
||||||
|
// Extract real part from the complex number struct.
|
||||||
|
ComplexStructBuilder complexStruct(transformed.complex());
|
||||||
|
Value real = complexStruct.real(rewriter, op.getLoc());
|
||||||
|
rewriter.replaceOp(op, real);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
|
||||||
|
using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(complex::ImOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
complex::ImOp::Adaptor transformed(operands);
|
||||||
|
|
||||||
|
// Extract imaginary part from the complex number struct.
|
||||||
|
ComplexStructBuilder complexStruct(transformed.complex());
|
||||||
|
Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
|
||||||
|
rewriter.replaceOp(op, imaginary);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BinaryComplexOperands {
|
||||||
|
std::complex<Value> lhs;
|
||||||
|
std::complex<Value> rhs;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OpTy>
|
||||||
|
BinaryComplexOperands
|
||||||
|
unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
typename OpTy::Adaptor transformed(operands);
|
||||||
|
|
||||||
|
// Extract real and imaginary values from operands.
|
||||||
|
BinaryComplexOperands unpacked;
|
||||||
|
ComplexStructBuilder lhs(transformed.lhs());
|
||||||
|
unpacked.lhs.real(lhs.real(rewriter, loc));
|
||||||
|
unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
|
||||||
|
ComplexStructBuilder rhs(transformed.rhs());
|
||||||
|
unpacked.rhs.real(rhs.real(rewriter, loc));
|
||||||
|
unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
|
||||||
|
|
||||||
|
return unpacked;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
|
||||||
|
using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(complex::AddOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
BinaryComplexOperands arg =
|
||||||
|
unpackBinaryComplexOperands<complex::AddOp>(op, operands, rewriter);
|
||||||
|
|
||||||
|
// Initialize complex number struct for result.
|
||||||
|
auto structType = typeConverter->convertType(op.getType());
|
||||||
|
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
|
||||||
|
|
||||||
|
// Emit IR to add complex numbers.
|
||||||
|
auto fmf = LLVM::FMFAttr::get({}, op.getContext());
|
||||||
|
Value real =
|
||||||
|
rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
|
||||||
|
Value imag =
|
||||||
|
rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
|
||||||
|
result.setReal(rewriter, loc, real);
|
||||||
|
result.setImaginary(rewriter, loc, imag);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, {result});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
|
||||||
|
using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(complex::SubOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
BinaryComplexOperands arg =
|
||||||
|
unpackBinaryComplexOperands<complex::SubOp>(op, operands, rewriter);
|
||||||
|
|
||||||
|
// Initialize complex number struct for result.
|
||||||
|
auto structType = typeConverter->convertType(op.getType());
|
||||||
|
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
|
||||||
|
|
||||||
|
// Emit IR to substract complex numbers.
|
||||||
|
auto fmf = LLVM::FMFAttr::get({}, op.getContext());
|
||||||
|
Value real =
|
||||||
|
rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
|
||||||
|
Value imag =
|
||||||
|
rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
|
||||||
|
result.setReal(rewriter, loc, real);
|
||||||
|
result.setImaginary(rewriter, loc, imag);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, {result});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void mlir::populateComplexToLLVMConversionPatterns(
|
||||||
|
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||||
|
// clang-format off
|
||||||
|
patterns.insert<
|
||||||
|
AddOpConversion,
|
||||||
|
CreateOpConversion,
|
||||||
|
ImOpConversion,
|
||||||
|
ReOpConversion,
|
||||||
|
SubOpConversion
|
||||||
|
>(converter);
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct ConvertComplexToLLVMPass
|
||||||
|
: public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
|
||||||
|
void runOnOperation() override;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void ConvertComplexToLLVMPass::runOnOperation() {
|
||||||
|
auto module = getOperation();
|
||||||
|
|
||||||
|
// Convert to the LLVM IR dialect using the converter defined above.
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
LLVMTypeConverter converter(&getContext());
|
||||||
|
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||||
|
populateComplexToLLVMConversionPatterns(converter, patterns);
|
||||||
|
|
||||||
|
LLVMConversionTarget target(getContext());
|
||||||
|
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||||
|
if (failed(applyFullConversion(module, target, std::move(patterns))))
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
mlir::createConvertComplexToLLVMPass() {
|
||||||
|
return std::make_unique<ConvertComplexToLLVMPass>();
|
||||||
|
}
|
|
@ -19,6 +19,10 @@ class StandardOpsDialect;
|
||||||
template <typename ConcreteDialect>
|
template <typename ConcreteDialect>
|
||||||
void registerDialect(DialectRegistry ®istry);
|
void registerDialect(DialectRegistry ®istry);
|
||||||
|
|
||||||
|
namespace complex {
|
||||||
|
class ComplexDialect;
|
||||||
|
} // end namespace complex
|
||||||
|
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
class GPUDialect;
|
class GPUDialect;
|
||||||
class GPUModuleOp;
|
class GPUModuleOp;
|
||||||
|
|
|
@ -3,6 +3,7 @@ add_subdirectory(ArmNeon)
|
||||||
add_subdirectory(ArmSVE)
|
add_subdirectory(ArmSVE)
|
||||||
add_subdirectory(Async)
|
add_subdirectory(Async)
|
||||||
add_subdirectory(AVX512)
|
add_subdirectory(AVX512)
|
||||||
|
add_subdirectory(Complex)
|
||||||
add_subdirectory(GPU)
|
add_subdirectory(GPU)
|
||||||
add_subdirectory(Linalg)
|
add_subdirectory(Linalg)
|
||||||
add_subdirectory(LLVMIR)
|
add_subdirectory(LLVMIR)
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(IR)
|
|
@ -0,0 +1,14 @@
|
||||||
|
add_mlir_dialect_library(MLIRComplex
|
||||||
|
ComplexOps.cpp
|
||||||
|
ComplexDialect.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Complex
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
MLIRComplexOpsIncGen
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRDialect
|
||||||
|
MLIRIR
|
||||||
|
)
|
|
@ -0,0 +1,16 @@
|
||||||
|
//===- ComplexDialect.cpp - MLIR Complex Dialect --------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||||
|
|
||||||
|
void mlir::complex::ComplexDialect::initialize() {
|
||||||
|
addOperations<
|
||||||
|
#define GET_OP_LIST
|
||||||
|
#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
|
||||||
|
>();
|
||||||
|
}
|
|
@ -0,0 +1,19 @@
|
||||||
|
//===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::complex;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TableGen'd op method definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
|
|
@ -0,0 +1,61 @@
|
||||||
|
// RUN: mlir-opt %s -split-input-file -convert-complex-to-llvm | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: llvm.func @complex_numbers()
|
||||||
|
// CHECK-NEXT: %[[REAL0:.*]] = llvm.mlir.constant(1.200000e+00 : f32) : f32
|
||||||
|
// CHECK-NEXT: %[[IMAG0:.*]] = llvm.mlir.constant(3.400000e+00 : f32) : f32
|
||||||
|
// CHECK-NEXT: %[[CPLX0:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)>
|
||||||
|
// CHECK-NEXT: %[[CPLX1:.*]] = llvm.insertvalue %[[REAL0]], %[[CPLX0]][0] : !llvm.struct<(f32, f32)>
|
||||||
|
// CHECK-NEXT: %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG0]], %[[CPLX1]][1] : !llvm.struct<(f32, f32)>
|
||||||
|
// CHECK-NEXT: %[[REAL1:.*]] = llvm.extractvalue %[[CPLX2:.*]][0] : !llvm.struct<(f32, f32)>
|
||||||
|
// CHECK-NEXT: %[[IMAG1:.*]] = llvm.extractvalue %[[CPLX2:.*]][1] : !llvm.struct<(f32, f32)>
|
||||||
|
// CHECK-NEXT: llvm.return
|
||||||
|
func @complex_numbers() {
|
||||||
|
%real0 = constant 1.2 : f32
|
||||||
|
%imag0 = constant 3.4 : f32
|
||||||
|
%cplx2 = complex.create %real0, %imag0 : complex<f32>
|
||||||
|
%real1 = complex.re%cplx2 : complex<f32>
|
||||||
|
%imag1 = complex.im %cplx2 : complex<f32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: llvm.func @complex_addition()
|
||||||
|
// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
|
||||||
|
// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
|
||||||
|
// CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)>
|
||||||
|
// CHECK-DAG: %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)>
|
||||||
|
// CHECK: %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)>
|
||||||
|
// CHECK-DAG: %[[C_REAL:.*]] = llvm.fadd %[[A_REAL]], %[[B_REAL]] : f64
|
||||||
|
// CHECK-DAG: %[[C_IMAG:.*]] = llvm.fadd %[[A_IMAG]], %[[B_IMAG]] : f64
|
||||||
|
// CHECK: %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)>
|
||||||
|
// CHECK: %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)>
|
||||||
|
func @complex_addition() {
|
||||||
|
%a_re = constant 1.2 : f64
|
||||||
|
%a_im = constant 3.4 : f64
|
||||||
|
%a = complex.create %a_re, %a_im : complex<f64>
|
||||||
|
%b_re = constant 5.6 : f64
|
||||||
|
%b_im = constant 7.8 : f64
|
||||||
|
%b = complex.create %b_re, %b_im : complex<f64>
|
||||||
|
%c = complex.add %a, %b : complex<f64>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: llvm.func @complex_substraction()
|
||||||
|
// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
|
||||||
|
// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
|
||||||
|
// CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)>
|
||||||
|
// CHECK-DAG: %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)>
|
||||||
|
// CHECK: %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)>
|
||||||
|
// CHECK-DAG: %[[C_REAL:.*]] = llvm.fsub %[[A_REAL]], %[[B_REAL]] : f64
|
||||||
|
// CHECK-DAG: %[[C_IMAG:.*]] = llvm.fsub %[[A_IMAG]], %[[B_IMAG]] : f64
|
||||||
|
// CHECK: %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)>
|
||||||
|
// CHECK: %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)>
|
||||||
|
func @complex_substraction() {
|
||||||
|
%a_re = constant 1.2 : f64
|
||||||
|
%a_im = constant 3.4 : f64
|
||||||
|
%a = complex.create %a_re, %a_im : complex<f64>
|
||||||
|
%b_re = constant 5.6 : f64
|
||||||
|
%b_im = constant 7.8 : f64
|
||||||
|
%b = complex.create %b_re, %b_im : complex<f64>
|
||||||
|
%c = complex.sub %a, %b : complex<f64>
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
|
||||||
|
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
|
||||||
|
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @ops(
|
||||||
|
// CHECK-SAME: [[F:%.*]]: f32) {
|
||||||
|
func @ops(%f: f32) {
|
||||||
|
// CHECK: [[C:%.*]] = complex.create [[F]], [[F]] : complex<f32>
|
||||||
|
%complex = complex.create %f, %f : complex<f32>
|
||||||
|
|
||||||
|
// CHECK: complex.re [[C]] : complex<f32>
|
||||||
|
%real = complex.re %complex : complex<f32>
|
||||||
|
|
||||||
|
// CHECK: complex.im [[C]] : complex<f32>
|
||||||
|
%imag = complex.im %complex : complex<f32>
|
||||||
|
|
||||||
|
// CHECK: complex.add [[C]], [[C]] : complex<f32>
|
||||||
|
%sum = complex.add %complex, %complex : complex<f32>
|
||||||
|
|
||||||
|
// CHECK: complex.sub [[C]], [[C]] : complex<f32>
|
||||||
|
%diff = complex.sub %complex, %complex : complex<f32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
// CHECK-NEXT: arm_sve
|
// CHECK-NEXT: arm_sve
|
||||||
// CHECK-NEXT: async
|
// CHECK-NEXT: async
|
||||||
// CHECK-NEXT: avx512
|
// CHECK-NEXT: avx512
|
||||||
|
// CHECK-NEXT: complex
|
||||||
// CHECK-NEXT: gpu
|
// CHECK-NEXT: gpu
|
||||||
// CHECK-NEXT: linalg
|
// CHECK-NEXT: linalg
|
||||||
// CHECK-NEXT: llvm
|
// CHECK-NEXT: llvm
|
||||||
|
|
Loading…
Reference in New Issue