forked from OSchip/llvm-project
[Flang] Add a factory class for creating Complex Ops
Use the factory class in the FIRBuilder. Add unit tests for the factory class function and the convert function of the Complex class. Reviewed By: clementval, rovka Differential Revision: https://reviews.llvm.org/D114125 Co-authored-by: Jean Perier <jperier@nvidia.com> Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
This commit is contained in:
parent
45e102a173
commit
a1f9bd32c5
|
@ -0,0 +1,89 @@
|
|||
//===-- Complex.h -- lowering of complex values -----------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef FORTRAN_OPTIMIZER_BUILDER_COMPLEX_H
|
||||
#define FORTRAN_OPTIMIZER_BUILDER_COMPLEX_H
|
||||
|
||||
#include "flang/Optimizer/Builder/FIRBuilder.h"
|
||||
|
||||
namespace fir::factory {
|
||||
|
||||
/// Helper to facilitate lowering of COMPLEX manipulations in FIR.
|
||||
class Complex {
|
||||
public:
|
||||
explicit Complex(FirOpBuilder &builder, mlir::Location loc)
|
||||
: builder(builder), loc(loc) {}
|
||||
Complex(const Complex &) = delete;
|
||||
|
||||
// The values of part enum members are meaningful for
|
||||
// InsertValueOp and ExtractValueOp so they are explicit.
|
||||
enum class Part { Real = 0, Imag = 1 };
|
||||
|
||||
/// Get the Complex Type. Determine the type. Do not create MLIR operations.
|
||||
mlir::Type getComplexPartType(mlir::Value cplx) const;
|
||||
mlir::Type getComplexPartType(mlir::Type complexType) const;
|
||||
|
||||
/// Complex operation creation. They create MLIR operations.
|
||||
mlir::Value createComplex(fir::KindTy kind, mlir::Value real,
|
||||
mlir::Value imag);
|
||||
|
||||
/// Create a complex value.
|
||||
mlir::Value createComplex(mlir::Type complexType, mlir::Value real,
|
||||
mlir::Value imag);
|
||||
|
||||
/// Returns the Real/Imag part of \p cplx
|
||||
mlir::Value extractComplexPart(mlir::Value cplx, bool isImagPart) {
|
||||
return isImagPart ? extract<Part::Imag>(cplx) : extract<Part::Real>(cplx);
|
||||
}
|
||||
|
||||
/// Returns (Real, Imag) pair of \p cplx
|
||||
std::pair<mlir::Value, mlir::Value> extractParts(mlir::Value cplx) {
|
||||
return {extract<Part::Real>(cplx), extract<Part::Imag>(cplx)};
|
||||
}
|
||||
|
||||
mlir::Value insertComplexPart(mlir::Value cplx, mlir::Value part,
|
||||
bool isImagPart) {
|
||||
return isImagPart ? insert<Part::Imag>(cplx, part)
|
||||
: insert<Part::Real>(cplx, part);
|
||||
}
|
||||
|
||||
protected:
|
||||
template <Part partId>
|
||||
mlir::Value extract(mlir::Value cplx) {
|
||||
return builder.create<fir::ExtractValueOp>(
|
||||
loc, getComplexPartType(cplx), cplx,
|
||||
builder.getArrayAttr({builder.getIntegerAttr(
|
||||
builder.getIndexType(), static_cast<int>(partId))}));
|
||||
}
|
||||
|
||||
template <Part partId>
|
||||
mlir::Value insert(mlir::Value cplx, mlir::Value part) {
|
||||
return builder.create<fir::InsertValueOp>(
|
||||
loc, cplx.getType(), cplx, part,
|
||||
builder.getArrayAttr({builder.getIntegerAttr(
|
||||
builder.getIndexType(), static_cast<int>(partId))}));
|
||||
}
|
||||
|
||||
template <Part partId>
|
||||
mlir::Value createPartId() {
|
||||
return builder.createIntegerConstant(loc, builder.getIndexType(),
|
||||
static_cast<int>(partId));
|
||||
}
|
||||
|
||||
private:
|
||||
FirOpBuilder &builder;
|
||||
mlir::Location loc;
|
||||
};
|
||||
|
||||
} // namespace fir::factory
|
||||
|
||||
#endif // FORTRAN_OPTIMIZER_BUILDER_COMPLEX_H
|
|
@ -57,6 +57,15 @@ public:
|
|||
/// Get a reference to the kind map.
|
||||
const fir::KindMapping &getKindMap() { return kindMap; }
|
||||
|
||||
/// The LHS and RHS are not always in agreement in terms of
|
||||
/// type. In some cases, the disagreement is between COMPLEX and other scalar
|
||||
/// types. In that case, the conversion must insert/extract out of a COMPLEX
|
||||
/// value to have the proper semantics and be strongly typed. For e.g for
|
||||
/// converting an integer/real to a complex, the real part is filled using
|
||||
/// the integer/real after type conversion and the imaginary part is zero.
|
||||
mlir::Value convertWithSemantics(mlir::Location loc, mlir::Type toTy,
|
||||
mlir::Value val);
|
||||
|
||||
/// Get the entry block of the current Function
|
||||
mlir::Block *getEntryBlock() { return &getFunction().front(); }
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
|||
add_flang_library(FIRBuilder
|
||||
BoxValue.cpp
|
||||
Character.cpp
|
||||
Complex.cpp
|
||||
DoLoopHelper.cpp
|
||||
FIRBuilder.cpp
|
||||
MutableBox.cpp
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
//===-- Complex.cpp -------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "flang/Optimizer/Builder/Complex.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Complex Factory implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
mlir::Type
|
||||
fir::factory::Complex::getComplexPartType(mlir::Type complexType) const {
|
||||
return builder.getRealType(complexType.cast<fir::ComplexType>().getFKind());
|
||||
}
|
||||
|
||||
mlir::Type fir::factory::Complex::getComplexPartType(mlir::Value cplx) const {
|
||||
return getComplexPartType(cplx.getType());
|
||||
}
|
||||
|
||||
mlir::Value fir::factory::Complex::createComplex(fir::KindTy kind,
|
||||
mlir::Value real,
|
||||
mlir::Value imag) {
|
||||
auto complexTy = fir::ComplexType::get(builder.getContext(), kind);
|
||||
return createComplex(complexTy, real, imag);
|
||||
}
|
||||
|
||||
mlir::Value fir::factory::Complex::createComplex(mlir::Type cplxTy,
|
||||
mlir::Value real,
|
||||
mlir::Value imag) {
|
||||
mlir::Value und = builder.create<fir::UndefOp>(loc, cplxTy);
|
||||
return insert<Part::Imag>(insert<Part::Real>(und, real), imag);
|
||||
}
|
|
@ -9,6 +9,7 @@
|
|||
#include "flang/Optimizer/Builder/FIRBuilder.h"
|
||||
#include "flang/Optimizer/Builder/BoxValue.h"
|
||||
#include "flang/Optimizer/Builder/Character.h"
|
||||
#include "flang/Optimizer/Builder/Complex.h"
|
||||
#include "flang/Optimizer/Builder/MutableBox.h"
|
||||
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
|
||||
#include "flang/Optimizer/Support/FatalError.h"
|
||||
|
@ -257,6 +258,33 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
|
|||
return glob;
|
||||
}
|
||||
|
||||
mlir::Value fir::FirOpBuilder::convertWithSemantics(mlir::Location loc,
|
||||
mlir::Type toTy,
|
||||
mlir::Value val) {
|
||||
assert(toTy && "store location must be typed");
|
||||
auto fromTy = val.getType();
|
||||
if (fromTy == toTy)
|
||||
return val;
|
||||
fir::factory::Complex helper{*this, loc};
|
||||
if ((fir::isa_real(fromTy) || fir::isa_integer(fromTy)) &&
|
||||
fir::isa_complex(toTy)) {
|
||||
// imaginary part is zero
|
||||
auto eleTy = helper.getComplexPartType(toTy);
|
||||
auto cast = createConvert(loc, eleTy, val);
|
||||
llvm::APFloat zero{
|
||||
kindMap.getFloatSemantics(toTy.cast<fir::ComplexType>().getFKind()), 0};
|
||||
auto imag = createRealConstant(loc, eleTy, zero);
|
||||
return helper.createComplex(toTy, cast, imag);
|
||||
}
|
||||
if (fir::isa_complex(fromTy) &&
|
||||
(fir::isa_integer(toTy) || fir::isa_real(toTy))) {
|
||||
// drop the imaginary part
|
||||
auto rp = helper.extractComplexPart(val, /*isImagPart=*/false);
|
||||
return createConvert(loc, toTy, rp);
|
||||
}
|
||||
return createConvert(loc, toTy, val);
|
||||
}
|
||||
|
||||
mlir::Value fir::FirOpBuilder::createConvert(mlir::Location loc,
|
||||
mlir::Type toTy, mlir::Value val) {
|
||||
if (val.getType() != toTy) {
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
//===- ComplexExprTest.cpp -- ComplexExpr unit tests ----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "flang/Optimizer/Builder/Complex.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "flang/Optimizer/Builder/FIRBuilder.h"
|
||||
#include "flang/Optimizer/Support/InitFIR.h"
|
||||
#include "flang/Optimizer/Support/KindMapping.h"
|
||||
|
||||
struct ComplexTest : public testing::Test {
|
||||
public:
|
||||
void SetUp() override {
|
||||
mlir::OpBuilder builder(&context);
|
||||
auto loc = builder.getUnknownLoc();
|
||||
|
||||
// Set up a Module with a dummy function operation inside.
|
||||
// Set the insertion point in the function entry block.
|
||||
mlir::ModuleOp mod = builder.create<mlir::ModuleOp>(loc);
|
||||
mlir::FuncOp func = mlir::FuncOp::create(
|
||||
loc, "func1", builder.getFunctionType(llvm::None, llvm::None));
|
||||
auto *entryBlock = func.addEntryBlock();
|
||||
mod.push_back(mod);
|
||||
builder.setInsertionPointToStart(entryBlock);
|
||||
|
||||
fir::support::loadDialects(context);
|
||||
kindMap = std::make_unique<fir::KindMapping>(&context);
|
||||
firBuilder = std::make_unique<fir::FirOpBuilder>(mod, *kindMap);
|
||||
helper = std::make_unique<fir::factory::Complex>(*firBuilder, loc);
|
||||
|
||||
// Init commonly used types
|
||||
realTy1 = mlir::FloatType::getF32(&context);
|
||||
complexTy1 = fir::ComplexType::get(&context, 4);
|
||||
integerTy1 = mlir::IntegerType::get(&context, 32);
|
||||
|
||||
// Create commonly used reals
|
||||
rOne = firBuilder->createRealConstant(loc, realTy1, 1u);
|
||||
rTwo = firBuilder->createRealConstant(loc, realTy1, 2u);
|
||||
rThree = firBuilder->createRealConstant(loc, realTy1, 3u);
|
||||
rFour = firBuilder->createRealConstant(loc, realTy1, 4u);
|
||||
}
|
||||
|
||||
mlir::MLIRContext context;
|
||||
std::unique_ptr<fir::KindMapping> kindMap;
|
||||
std::unique_ptr<fir::FirOpBuilder> firBuilder;
|
||||
std::unique_ptr<fir::factory::Complex> helper;
|
||||
|
||||
// Commonly used real/complex/integer types
|
||||
mlir::FloatType realTy1;
|
||||
fir::ComplexType complexTy1;
|
||||
mlir::IntegerType integerTy1;
|
||||
|
||||
// Commonly used real numbers
|
||||
mlir::Value rOne;
|
||||
mlir::Value rTwo;
|
||||
mlir::Value rThree;
|
||||
mlir::Value rFour;
|
||||
};
|
||||
|
||||
TEST_F(ComplexTest, verifyTypes) {
|
||||
mlir::Value cVal1 = helper->createComplex(complexTy1, rOne, rTwo);
|
||||
mlir::Value cVal2 = helper->createComplex(4, rOne, rTwo);
|
||||
EXPECT_TRUE(fir::isa_complex(cVal1.getType()));
|
||||
EXPECT_TRUE(fir::isa_complex(cVal2.getType()));
|
||||
EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal1)));
|
||||
EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal2)));
|
||||
|
||||
mlir::Value real1 = helper->extractComplexPart(cVal1, /*isImagPart=*/false);
|
||||
mlir::Value imag1 = helper->extractComplexPart(cVal1, /*isImagPart=*/true);
|
||||
mlir::Value real2 = helper->extractComplexPart(cVal2, /*isImagPart=*/false);
|
||||
mlir::Value imag2 = helper->extractComplexPart(cVal2, /*isImagPart=*/true);
|
||||
EXPECT_EQ(realTy1, real1.getType());
|
||||
EXPECT_EQ(realTy1, imag1.getType());
|
||||
EXPECT_EQ(realTy1, real2.getType());
|
||||
EXPECT_EQ(realTy1, imag2.getType());
|
||||
|
||||
mlir::Value cVal3 =
|
||||
helper->insertComplexPart(cVal1, rThree, /*isImagPart=*/false);
|
||||
mlir::Value cVal4 =
|
||||
helper->insertComplexPart(cVal3, rFour, /*isImagPart=*/true);
|
||||
EXPECT_TRUE(fir::isa_complex(cVal4.getType()));
|
||||
EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal4)));
|
||||
}
|
||||
|
||||
TEST_F(ComplexTest, verifyConvertWithSemantics) {
|
||||
auto loc = firBuilder->getUnknownLoc();
|
||||
rOne = firBuilder->createRealConstant(loc, realTy1, 1u);
|
||||
// Convert real to complex
|
||||
mlir::Value v1 = firBuilder->convertWithSemantics(loc, complexTy1, rOne);
|
||||
EXPECT_TRUE(fir::isa_complex(v1.getType()));
|
||||
|
||||
// Convert complex to integer
|
||||
mlir::Value v2 = firBuilder->convertWithSemantics(loc, integerTy1, v1);
|
||||
EXPECT_TRUE(v2.getType().isa<mlir::IntegerType>());
|
||||
EXPECT_TRUE(mlir::dyn_cast<fir::ConvertOp>(v2.getDefiningOp()));
|
||||
}
|
|
@ -10,6 +10,7 @@ set(LIBS
|
|||
|
||||
add_flang_unittest(FlangOptimizerTests
|
||||
Builder/CharacterTest.cpp
|
||||
Builder/ComplexTest.cpp
|
||||
Builder/DoLoopHelperTest.cpp
|
||||
Builder/FIRBuilderTest.cpp
|
||||
FIRContextTest.cpp
|
||||
|
|
Loading…
Reference in New Issue