[fir] Add IfBuilder and utility functions

In order to reduct the size of D111337. The IfBuilder and the two
utility functions genIsNotNull and genIsNull have been extracted in
a separate patch with dedicated unittests.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: Leporacanthicus

Differential Revision: https://reviews.llvm.org/D111796

Co-authored-by: Jean Perier <jperier@nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
This commit is contained in:
Valentin Clement 2021-10-17 20:54:49 +02:00
parent dbf5dc8930
commit f17f694a0f
No known key found for this signature in database
GPG Key ID: 086D54783C928776
4 changed files with 197 additions and 0 deletions

View File

@ -38,6 +38,14 @@ public:
const fir::KindMapping &kindMap)
: OpBuilder{builder} {}
/// Get the integer type whose bit width corresponds to the width of pointer
/// types, or is bigger.
mlir::Type getIntPtrType() {
// TODO: Delay the need of such type until codegen or find a way to use
// llvm::DataLayout::getPointerSizeInBits here.
return getI64Type();
}
/// Create an integer constant of type \p type and value \p i.
mlir::Value createIntegerConstant(mlir::Location loc, mlir::Type integerType,
std::int64_t i);
@ -50,6 +58,74 @@ public:
mlir::Value convertToIndexType(mlir::Location loc, mlir::Value val) {
return createConvert(loc, getIndexType(), val);
}
//===--------------------------------------------------------------------===//
// If-Then-Else generation helper
//===--------------------------------------------------------------------===//
/// Helper class to create if-then-else in a structured way:
/// Usage: genIfOp().genThen([&](){...}).genElse([&](){...}).end();
/// Alternatively, getResults() can be used instead of end() to end the ifOp
/// and get the ifOp results.
class IfBuilder {
public:
IfBuilder(fir::IfOp ifOp, FirOpBuilder &builder)
: ifOp{ifOp}, builder{builder} {}
template <typename CC>
IfBuilder &genThen(CC func) {
builder.setInsertionPointToStart(&ifOp.thenRegion().front());
func();
return *this;
}
template <typename CC>
IfBuilder &genElse(CC func) {
assert(!ifOp.elseRegion().empty() && "must have else region");
builder.setInsertionPointToStart(&ifOp.elseRegion().front());
func();
return *this;
}
void end() { builder.setInsertionPointAfter(ifOp); }
/// End the IfOp and return the results if any.
mlir::Operation::result_range getResults() {
end();
return ifOp.getResults();
}
fir::IfOp &getIfOp() { return ifOp; };
private:
fir::IfOp ifOp;
FirOpBuilder &builder;
};
/// Create an IfOp and returns an IfBuilder that can generate the else/then
/// bodies.
IfBuilder genIfOp(mlir::Location loc, mlir::TypeRange results,
mlir::Value cdt, bool withElseRegion) {
auto op = create<fir::IfOp>(loc, results, cdt, withElseRegion);
return IfBuilder(op, *this);
}
/// Create an IfOp with no "else" region, and no result values.
/// Usage: genIfThen(loc, cdt).genThen(lambda).end();
IfBuilder genIfThen(mlir::Location loc, mlir::Value cdt) {
auto op = create<fir::IfOp>(loc, llvm::None, cdt, false);
return IfBuilder(op, *this);
}
/// Create an IfOp with an "else" region, and no result values.
/// Usage: genIfThenElse(loc, cdt).genThen(lambda).genElse(lambda).end();
IfBuilder genIfThenElse(mlir::Location loc, mlir::Value cdt) {
auto op = create<fir::IfOp>(loc, llvm::None, cdt, true);
return IfBuilder(op, *this);
}
/// Generate code testing \p addr is not a null address.
mlir::Value genIsNotNull(mlir::Location loc, mlir::Value addr);
/// Generate code testing \p addr is a null address.
mlir::Value genIsNull(mlir::Location loc, mlir::Value addr);
};
} // namespace fir

View File

@ -22,3 +22,22 @@ mlir::Value fir::FirOpBuilder::createConvert(mlir::Location loc,
}
return val;
}
static mlir::Value genNullPointerComparison(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value addr,
arith::CmpIPredicate condition) {
auto intPtrTy = builder.getIntPtrType();
auto ptrToInt = builder.createConvert(loc, intPtrTy, addr);
auto c0 = builder.createIntegerConstant(loc, intPtrTy, 0);
return builder.create<arith::CmpIOp>(loc, condition, ptrToInt, c0);
}
mlir::Value fir::FirOpBuilder::genIsNotNull(mlir::Location loc,
mlir::Value addr) {
return genNullPointerComparison(*this, loc, addr, arith::CmpIPredicate::ne);
}
mlir::Value fir::FirOpBuilder::genIsNull(mlir::Location loc, mlir::Value addr) {
return genNullPointerComparison(*this, loc, addr, arith::CmpIPredicate::eq);
}

View File

@ -0,0 +1,101 @@
//===- FIRBuilderTest.cpp -- FIRBuilder unit tests ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "gtest/gtest.h"
#include "flang/Optimizer/Support/InitFIR.h"
#include "flang/Optimizer/Support/KindMapping.h"
struct FIRBuilderTest : public testing::Test {
public:
void SetUp() override {
fir::KindMapping kindMap(&context);
mlir::OpBuilder builder(&context);
firBuilder = std::make_unique<fir::FirOpBuilder>(builder, kindMap);
fir::support::loadDialects(context);
}
fir::FirOpBuilder &getBuilder() { return *firBuilder; }
mlir::MLIRContext context;
std::unique_ptr<fir::FirOpBuilder> firBuilder;
};
static arith::CmpIOp createCondition(fir::FirOpBuilder &builder) {
auto loc = builder.getUnknownLoc();
auto zero1 = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
auto zero2 = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
return builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, zero1, zero2);
}
//===----------------------------------------------------------------------===//
// IfBuilder tests
//===----------------------------------------------------------------------===//
TEST_F(FIRBuilderTest, genIfThen) {
auto builder = getBuilder();
auto loc = builder.getUnknownLoc();
auto cdt = createCondition(builder);
auto ifBuilder = builder.genIfThen(loc, cdt);
EXPECT_FALSE(ifBuilder.getIfOp().thenRegion().empty());
EXPECT_TRUE(ifBuilder.getIfOp().elseRegion().empty());
}
TEST_F(FIRBuilderTest, genIfThenElse) {
auto builder = getBuilder();
auto loc = builder.getUnknownLoc();
auto cdt = createCondition(builder);
auto ifBuilder = builder.genIfThenElse(loc, cdt);
EXPECT_FALSE(ifBuilder.getIfOp().thenRegion().empty());
EXPECT_FALSE(ifBuilder.getIfOp().elseRegion().empty());
}
TEST_F(FIRBuilderTest, genIfWithThen) {
auto builder = getBuilder();
auto loc = builder.getUnknownLoc();
auto cdt = createCondition(builder);
auto ifBuilder = builder.genIfOp(loc, {}, cdt, false);
EXPECT_FALSE(ifBuilder.getIfOp().thenRegion().empty());
EXPECT_TRUE(ifBuilder.getIfOp().elseRegion().empty());
}
TEST_F(FIRBuilderTest, genIfWithThenAndElse) {
auto builder = getBuilder();
auto loc = builder.getUnknownLoc();
auto cdt = createCondition(builder);
auto ifBuilder = builder.genIfOp(loc, {}, cdt, true);
EXPECT_FALSE(ifBuilder.getIfOp().thenRegion().empty());
EXPECT_FALSE(ifBuilder.getIfOp().elseRegion().empty());
}
//===----------------------------------------------------------------------===//
// Helper functions tests
//===----------------------------------------------------------------------===//
TEST_F(FIRBuilderTest, genIsNotNull) {
auto builder = getBuilder();
auto loc = builder.getUnknownLoc();
auto dummyValue =
builder.createIntegerConstant(loc, builder.getIndexType(), 0);
auto res = builder.genIsNotNull(loc, dummyValue);
EXPECT_TRUE(mlir::isa<arith::CmpIOp>(res.getDefiningOp()));
auto cmpOp = dyn_cast<arith::CmpIOp>(res.getDefiningOp());
EXPECT_EQ(arith::CmpIPredicate::ne, cmpOp.predicate());
}
TEST_F(FIRBuilderTest, genIsNull) {
auto builder = getBuilder();
auto loc = builder.getUnknownLoc();
auto dummyValue =
builder.createIntegerConstant(loc, builder.getIndexType(), 0);
auto res = builder.genIsNull(loc, dummyValue);
EXPECT_TRUE(mlir::isa<arith::CmpIOp>(res.getDefiningOp()));
auto cmpOp = dyn_cast<arith::CmpIOp>(res.getDefiningOp());
EXPECT_EQ(arith::CmpIPredicate::eq, cmpOp.predicate());
}

View File

@ -10,6 +10,7 @@ set(LIBS
add_flang_unittest(FlangOptimizerTests
Builder/DoLoopHelperTest.cpp
Builder/FIRBuilderTest.cpp
FIRContextTest.cpp
InternalNamesTest.cpp
KindMappingTest.cpp