diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h index 2d9cc2a78bc0..900a974819dc 100644 --- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h +++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h @@ -38,6 +38,14 @@ public: const fir::KindMapping &kindMap) : OpBuilder{builder} {} + /// Get the integer type whose bit width corresponds to the width of pointer + /// types, or is bigger. + mlir::Type getIntPtrType() { + // TODO: Delay the need of such type until codegen or find a way to use + // llvm::DataLayout::getPointerSizeInBits here. + return getI64Type(); + } + /// Create an integer constant of type \p type and value \p i. mlir::Value createIntegerConstant(mlir::Location loc, mlir::Type integerType, std::int64_t i); @@ -50,6 +58,74 @@ public: mlir::Value convertToIndexType(mlir::Location loc, mlir::Value val) { return createConvert(loc, getIndexType(), val); } + + //===--------------------------------------------------------------------===// + // If-Then-Else generation helper + //===--------------------------------------------------------------------===// + + /// Helper class to create if-then-else in a structured way: + /// Usage: genIfOp().genThen([&](){...}).genElse([&](){...}).end(); + /// Alternatively, getResults() can be used instead of end() to end the ifOp + /// and get the ifOp results. + class IfBuilder { + public: + IfBuilder(fir::IfOp ifOp, FirOpBuilder &builder) + : ifOp{ifOp}, builder{builder} {} + template + IfBuilder &genThen(CC func) { + builder.setInsertionPointToStart(&ifOp.thenRegion().front()); + func(); + return *this; + } + template + IfBuilder &genElse(CC func) { + assert(!ifOp.elseRegion().empty() && "must have else region"); + builder.setInsertionPointToStart(&ifOp.elseRegion().front()); + func(); + return *this; + } + void end() { builder.setInsertionPointAfter(ifOp); } + + /// End the IfOp and return the results if any. + mlir::Operation::result_range getResults() { + end(); + return ifOp.getResults(); + } + + fir::IfOp &getIfOp() { return ifOp; }; + + private: + fir::IfOp ifOp; + FirOpBuilder &builder; + }; + + /// Create an IfOp and returns an IfBuilder that can generate the else/then + /// bodies. + IfBuilder genIfOp(mlir::Location loc, mlir::TypeRange results, + mlir::Value cdt, bool withElseRegion) { + auto op = create(loc, results, cdt, withElseRegion); + return IfBuilder(op, *this); + } + + /// Create an IfOp with no "else" region, and no result values. + /// Usage: genIfThen(loc, cdt).genThen(lambda).end(); + IfBuilder genIfThen(mlir::Location loc, mlir::Value cdt) { + auto op = create(loc, llvm::None, cdt, false); + return IfBuilder(op, *this); + } + + /// Create an IfOp with an "else" region, and no result values. + /// Usage: genIfThenElse(loc, cdt).genThen(lambda).genElse(lambda).end(); + IfBuilder genIfThenElse(mlir::Location loc, mlir::Value cdt) { + auto op = create(loc, llvm::None, cdt, true); + return IfBuilder(op, *this); + } + + /// Generate code testing \p addr is not a null address. + mlir::Value genIsNotNull(mlir::Location loc, mlir::Value addr); + + /// Generate code testing \p addr is a null address. + mlir::Value genIsNull(mlir::Location loc, mlir::Value addr); }; } // namespace fir diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp index b800ecaeb5ad..6f98e6029916 100644 --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -22,3 +22,22 @@ mlir::Value fir::FirOpBuilder::createConvert(mlir::Location loc, } return val; } + +static mlir::Value genNullPointerComparison(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::Value addr, + arith::CmpIPredicate condition) { + auto intPtrTy = builder.getIntPtrType(); + auto ptrToInt = builder.createConvert(loc, intPtrTy, addr); + auto c0 = builder.createIntegerConstant(loc, intPtrTy, 0); + return builder.create(loc, condition, ptrToInt, c0); +} + +mlir::Value fir::FirOpBuilder::genIsNotNull(mlir::Location loc, + mlir::Value addr) { + return genNullPointerComparison(*this, loc, addr, arith::CmpIPredicate::ne); +} + +mlir::Value fir::FirOpBuilder::genIsNull(mlir::Location loc, mlir::Value addr) { + return genNullPointerComparison(*this, loc, addr, arith::CmpIPredicate::eq); +} diff --git a/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp b/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp new file mode 100644 index 000000000000..07ae0bcd3386 --- /dev/null +++ b/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp @@ -0,0 +1,101 @@ +//===- FIRBuilderTest.cpp -- FIRBuilder unit tests ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "gtest/gtest.h" +#include "flang/Optimizer/Support/InitFIR.h" +#include "flang/Optimizer/Support/KindMapping.h" + +struct FIRBuilderTest : public testing::Test { +public: + void SetUp() override { + fir::KindMapping kindMap(&context); + mlir::OpBuilder builder(&context); + firBuilder = std::make_unique(builder, kindMap); + fir::support::loadDialects(context); + } + + fir::FirOpBuilder &getBuilder() { return *firBuilder; } + + mlir::MLIRContext context; + std::unique_ptr firBuilder; +}; + +static arith::CmpIOp createCondition(fir::FirOpBuilder &builder) { + auto loc = builder.getUnknownLoc(); + auto zero1 = builder.createIntegerConstant(loc, builder.getIndexType(), 0); + auto zero2 = builder.createIntegerConstant(loc, builder.getIndexType(), 0); + return builder.create( + loc, arith::CmpIPredicate::eq, zero1, zero2); +} + +//===----------------------------------------------------------------------===// +// IfBuilder tests +//===----------------------------------------------------------------------===// + +TEST_F(FIRBuilderTest, genIfThen) { + auto builder = getBuilder(); + auto loc = builder.getUnknownLoc(); + auto cdt = createCondition(builder); + auto ifBuilder = builder.genIfThen(loc, cdt); + EXPECT_FALSE(ifBuilder.getIfOp().thenRegion().empty()); + EXPECT_TRUE(ifBuilder.getIfOp().elseRegion().empty()); +} + +TEST_F(FIRBuilderTest, genIfThenElse) { + auto builder = getBuilder(); + auto loc = builder.getUnknownLoc(); + auto cdt = createCondition(builder); + auto ifBuilder = builder.genIfThenElse(loc, cdt); + EXPECT_FALSE(ifBuilder.getIfOp().thenRegion().empty()); + EXPECT_FALSE(ifBuilder.getIfOp().elseRegion().empty()); +} + +TEST_F(FIRBuilderTest, genIfWithThen) { + auto builder = getBuilder(); + auto loc = builder.getUnknownLoc(); + auto cdt = createCondition(builder); + auto ifBuilder = builder.genIfOp(loc, {}, cdt, false); + EXPECT_FALSE(ifBuilder.getIfOp().thenRegion().empty()); + EXPECT_TRUE(ifBuilder.getIfOp().elseRegion().empty()); +} + +TEST_F(FIRBuilderTest, genIfWithThenAndElse) { + auto builder = getBuilder(); + auto loc = builder.getUnknownLoc(); + auto cdt = createCondition(builder); + auto ifBuilder = builder.genIfOp(loc, {}, cdt, true); + EXPECT_FALSE(ifBuilder.getIfOp().thenRegion().empty()); + EXPECT_FALSE(ifBuilder.getIfOp().elseRegion().empty()); +} + +//===----------------------------------------------------------------------===// +// Helper functions tests +//===----------------------------------------------------------------------===// + +TEST_F(FIRBuilderTest, genIsNotNull) { + auto builder = getBuilder(); + auto loc = builder.getUnknownLoc(); + auto dummyValue = + builder.createIntegerConstant(loc, builder.getIndexType(), 0); + auto res = builder.genIsNotNull(loc, dummyValue); + EXPECT_TRUE(mlir::isa(res.getDefiningOp())); + auto cmpOp = dyn_cast(res.getDefiningOp()); + EXPECT_EQ(arith::CmpIPredicate::ne, cmpOp.predicate()); +} + +TEST_F(FIRBuilderTest, genIsNull) { + auto builder = getBuilder(); + auto loc = builder.getUnknownLoc(); + auto dummyValue = + builder.createIntegerConstant(loc, builder.getIndexType(), 0); + auto res = builder.genIsNull(loc, dummyValue); + EXPECT_TRUE(mlir::isa(res.getDefiningOp())); + auto cmpOp = dyn_cast(res.getDefiningOp()); + EXPECT_EQ(arith::CmpIPredicate::eq, cmpOp.predicate()); +} diff --git a/flang/unittests/Optimizer/CMakeLists.txt b/flang/unittests/Optimizer/CMakeLists.txt index aea74823bbe9..a8168fd4a334 100644 --- a/flang/unittests/Optimizer/CMakeLists.txt +++ b/flang/unittests/Optimizer/CMakeLists.txt @@ -10,6 +10,7 @@ set(LIBS add_flang_unittest(FlangOptimizerTests Builder/DoLoopHelperTest.cpp + Builder/FIRBuilderTest.cpp FIRContextTest.cpp InternalNamesTest.cpp KindMappingTest.cpp