[fir] Add the DoLoopHelper

Add the DoLoopHelper. Some helpers functions
to create fir.do_loop operations.

This code was part of D111337 and was extracted in order to
make the patch easier to review.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D111713

Co-authored-by: Valentin Clement <clementval@gmail.com>
Co-authored-by: Jean Perier <jperier@nvidia.com>
This commit is contained in:
Eric Schweitz 2021-10-13 21:46:18 +02:00 committed by Valentin Clement
parent a8a64eaafc
commit bde89ac7f1
No known key found for this signature in database
GPG Key ID: 086D54783C928776
8 changed files with 284 additions and 0 deletions

View File

@ -0,0 +1,50 @@
//===-- DoLoopHelper.h -- gen fir.do_loop ops -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//
#ifndef FORTRAN_OPTIMIZER_BUILDER_DOLOOPHELPER_H
#define FORTRAN_OPTIMIZER_BUILDER_DOLOOPHELPER_H
#include "flang/Optimizer/Builder/FIRBuilder.h"
namespace fir::factory {
/// Helper to build fir.do_loop Ops.
class DoLoopHelper {
public:
explicit DoLoopHelper(fir::FirOpBuilder &builder, mlir::Location loc)
: builder(builder), loc(loc) {}
DoLoopHelper(const DoLoopHelper &) = delete;
/// Type of a callback to generate the loop body.
using BodyGenerator = std::function<void(fir::FirOpBuilder &, mlir::Value)>;
/// Build loop [\p lb, \p ub] with step \p step.
/// If \p step is an empty value, 1 is used for the step.
fir::DoLoopOp createLoop(mlir::Value lb, mlir::Value ub, mlir::Value step,
const BodyGenerator &bodyGenerator);
/// Build loop [\p lb, \p ub] with step 1.
fir::DoLoopOp createLoop(mlir::Value lb, mlir::Value ub,
const BodyGenerator &bodyGenerator);
/// Build loop [0, \p count) with step 1.
fir::DoLoopOp createLoop(mlir::Value count,
const BodyGenerator &bodyGenerator);
private:
fir::FirOpBuilder &builder;
mlir::Location loc;
};
} // namespace fir::factory
#endif // FORTRAN_OPTIMIZER_BUILDER_DOLOOPHELPER_H

View File

@ -0,0 +1,60 @@
//===-- FirBuilder.h -- FIR operation builder -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Builder routines for constructing the FIR dialect of MLIR. As FIR is a
// dialect of MLIR, it makes extensive use of MLIR interfaces and MLIR's coding
// style (https://mlir.llvm.org/getting_started/DeveloperGuide/) is used in this
// module.
//
//===----------------------------------------------------------------------===//
#ifndef FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
#define FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/KindMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
namespace fir {
//===----------------------------------------------------------------------===//
// FirOpBuilder
//===----------------------------------------------------------------------===//
/// Extends the MLIR OpBuilder to provide methods for building common FIR
/// patterns.
class FirOpBuilder : public mlir::OpBuilder {
public:
explicit FirOpBuilder(mlir::Operation *op, const fir::KindMapping &kindMap)
: OpBuilder{op}, kindMap{kindMap} {}
explicit FirOpBuilder(mlir::OpBuilder &builder,
const fir::KindMapping &kindMap)
: OpBuilder{builder}, kindMap{kindMap} {}
/// Create an integer constant of type \p type and value \p i.
mlir::Value createIntegerConstant(mlir::Location loc, mlir::Type integerType,
std::int64_t i);
/// Lazy creation of fir.convert op.
mlir::Value createConvert(mlir::Location loc, mlir::Type toTy,
mlir::Value val);
/// Cast the input value to IndexType.
mlir::Value convertToIndexType(mlir::Location loc, mlir::Value val) {
return createConvert(loc, getIndexType(), val);
}
private:
const KindMapping &kindMap;
};
} // namespace fir
#endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H

View File

@ -0,0 +1,15 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
add_flang_library(FIRBuilder
DoLoopHelper.cpp
FIRBuilder.cpp
DEPENDS
FIRDialect
FIRSupport
${dialect_libs}
LINK_LIBS
FIRDialect
FIRSupport
${dialect_libs}
)

View File

@ -0,0 +1,48 @@
//===-- DoLoopHelper.cpp --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/DoLoopHelper.h"
//===----------------------------------------------------------------------===//
// DoLoopHelper implementation
//===----------------------------------------------------------------------===//
fir::DoLoopOp
fir::factory::DoLoopHelper::createLoop(mlir::Value lb, mlir::Value ub,
mlir::Value step,
const BodyGenerator &bodyGenerator) {
auto lbi = builder.convertToIndexType(loc, lb);
auto ubi = builder.convertToIndexType(loc, ub);
assert(step && "step must be an actual Value");
auto inc = builder.convertToIndexType(loc, step);
auto loop = builder.create<fir::DoLoopOp>(loc, lbi, ubi, inc);
auto insertPt = builder.saveInsertionPoint();
builder.setInsertionPointToStart(loop.getBody());
auto index = loop.getInductionVar();
bodyGenerator(builder, index);
builder.restoreInsertionPoint(insertPt);
return loop;
}
fir::DoLoopOp
fir::factory::DoLoopHelper::createLoop(mlir::Value lb, mlir::Value ub,
const BodyGenerator &bodyGenerator) {
return createLoop(
lb, ub, builder.createIntegerConstant(loc, builder.getIndexType(), 1),
bodyGenerator);
}
fir::DoLoopOp
fir::factory::DoLoopHelper::createLoop(mlir::Value count,
const BodyGenerator &bodyGenerator) {
auto indexType = builder.getIndexType();
auto zero = builder.createIntegerConstant(loc, indexType, 0);
auto one = builder.createIntegerConstant(loc, count.getType(), 1);
auto up = builder.create<arith::SubIOp>(loc, count, one);
return createLoop(zero, up, one, bodyGenerator);
}

View File

@ -0,0 +1,24 @@
//===-- FIRBuilder.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/FIRBuilder.h"
mlir::Value fir::FirOpBuilder::createIntegerConstant(mlir::Location loc,
mlir::Type ty,
std::int64_t cst) {
return create<mlir::ConstantOp>(loc, ty, getIntegerAttr(ty, cst));
}
mlir::Value fir::FirOpBuilder::createConvert(mlir::Location loc,
mlir::Type toTy, mlir::Value val) {
if (val.getType() != toTy) {
assert(!fir::isa_derived(toTy));
return create<fir::ConvertOp>(loc, toTy, val);
}
return val;
}

View File

@ -1,3 +1,4 @@
add_subdirectory(Builder)
add_subdirectory(CodeGen)
add_subdirectory(Dialect)
add_subdirectory(Support)

View File

@ -0,0 +1,84 @@
//===- DoLoopHelper.cpp -- DoLoopHelper unit tests ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/DoLoopHelper.h"
#include "gtest/gtest.h"
#include "flang/Optimizer/Support/InitFIR.h"
#include "flang/Optimizer/Support/KindMapping.h"
#include <string>
struct DoLoopHelperTest : public testing::Test {
public:
void SetUp() {
fir::KindMapping kindMap(&context);
mlir::OpBuilder builder(&context);
firBuilder = new fir::FirOpBuilder(builder, kindMap);
fir::support::loadDialects(context);
}
void TearDown() { delete firBuilder; }
fir::FirOpBuilder &getBuilder() { return *firBuilder; }
mlir::MLIRContext context;
fir::FirOpBuilder *firBuilder;
};
void checkConstantValue(const mlir::Value &value, int64_t v) {
EXPECT_TRUE(mlir::isa<ConstantOp>(value.getDefiningOp()));
auto cstOp = dyn_cast<ConstantOp>(value.getDefiningOp());
auto valueAttr = cstOp.getValue().dyn_cast_or_null<IntegerAttr>();
EXPECT_EQ(v, valueAttr.getInt());
}
TEST_F(DoLoopHelperTest, createLoopWithCountTest) {
auto firBuilder = getBuilder();
fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc());
auto c10 = firBuilder.createIntegerConstant(
firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 10);
auto loop =
helper.createLoop(c10, [&](fir::FirOpBuilder &, mlir::Value index) {});
checkConstantValue(loop.lowerBound(), 0);
EXPECT_TRUE(mlir::isa<arith::SubIOp>(loop.upperBound().getDefiningOp()));
auto subOp = dyn_cast<arith::SubIOp>(loop.upperBound().getDefiningOp());
EXPECT_EQ(c10, subOp.lhs());
checkConstantValue(subOp.rhs(), 1);
checkConstantValue(loop.step(), 1);
}
TEST_F(DoLoopHelperTest, createLoopWithLowerAndUpperBound) {
auto firBuilder = getBuilder();
fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc());
auto lb = firBuilder.createIntegerConstant(
firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 1);
auto ub = firBuilder.createIntegerConstant(
firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 20);
auto loop =
helper.createLoop(lb, ub, [&](fir::FirOpBuilder &, mlir::Value index) {});
checkConstantValue(loop.lowerBound(), 1);
checkConstantValue(loop.upperBound(), 20);
checkConstantValue(loop.step(), 1);
}
TEST_F(DoLoopHelperTest, createLoopWithStep) {
auto firBuilder = getBuilder();
fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc());
auto lb = firBuilder.createIntegerConstant(
firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 1);
auto ub = firBuilder.createIntegerConstant(
firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 20);
auto step = firBuilder.createIntegerConstant(
firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 2);
auto loop = helper.createLoop(
lb, ub, step, [&](fir::FirOpBuilder &, mlir::Value index) {});
checkConstantValue(loop.lowerBound(), 1);
checkConstantValue(loop.upperBound(), 20);
checkConstantValue(loop.step(), 2);
}

View File

@ -1,6 +1,7 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
set(LIBS
FIRBuilder
FIRCodeGen
FIRDialect
FIRSupport
@ -8,6 +9,7 @@ set(LIBS
)
add_flang_unittest(FlangOptimizerTests
Builder/DoLoopHelperTest.cpp
FIRContextTest.cpp
InternalNamesTest.cpp
KindMappingTest.cpp