[flang] Lower basic derived types

This patch lowers basic derived type to FIR.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D121383

Co-authored-by: V Donaldson <vdonaldson@nvidia.com>
Co-authored-by: Jean Perier <jperier@nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
This commit is contained in:
Valentin Clement 2022-03-10 18:06:20 +01:00
parent 58966dd42b
commit 589d51ea9f
No known key found for this signature in database
GPG Key ID: 086D54783C928776
7 changed files with 331 additions and 22 deletions

View File

@ -137,8 +137,6 @@ public:
// Types
//===--------------------------------------------------------------------===//
/// Generate the type of a DataRef
virtual mlir::Type genType(const Fortran::evaluate::DataRef &) = 0;
/// Generate the type of an Expr
virtual mlir::Type genType(const SomeExpr &) = 0;
/// Generate the type of a Symbol
@ -149,6 +147,8 @@ public:
virtual mlir::Type
genType(Fortran::common::TypeCategory tc, int kind,
llvm::ArrayRef<std::int64_t> lenParameters = llvm::None) = 0;
/// Generate the type from a DerivedTypeSpec.
virtual mlir::Type genType(const Fortran::semantics::DerivedTypeSpec &) = 0;
/// Generate the type from a Variable
virtual mlir::Type genType(const pft::Variable &) = 0;

View File

@ -44,6 +44,7 @@ struct SomeType;
namespace semantics {
class Symbol;
class DerivedTypeSpec;
} // namespace semantics
namespace lower {
@ -62,6 +63,11 @@ using LenParameterTy = std::int64_t;
mlir::Type getFIRType(mlir::MLIRContext *ctxt, common::TypeCategory tc,
int kind, llvm::ArrayRef<LenParameterTy>);
/// Get a FIR type for a derived type
mlir::Type
translateDerivedTypeToFIRType(Fortran::lower::AbstractConverter &,
const Fortran::semantics::DerivedTypeSpec &);
/// Translate a SomeExpr to an mlir::Type.
mlir::Type translateSomeExprToFIRType(Fortran::lower::AbstractConverter &,
const SomeExpr &expr);

View File

@ -241,26 +241,26 @@ public:
return foldingContext;
}
mlir::Type genType(const Fortran::evaluate::DataRef &) override final {
TODO_NOLOC("Not implemented genType DataRef. Needed for more complex "
"expression lowering");
}
mlir::Type genType(const Fortran::lower::SomeExpr &expr) override final {
return Fortran::lower::translateSomeExprToFIRType(*this, expr);
}
mlir::Type genType(Fortran::lower::SymbolRef sym) override final {
return Fortran::lower::translateSymbolToFIRType(*this, sym);
}
mlir::Type genType(Fortran::common::TypeCategory tc) override final {
TODO_NOLOC("Not implemented genType TypeCategory. Needed for more complex "
"expression lowering");
}
mlir::Type
genType(Fortran::common::TypeCategory tc, int kind,
llvm::ArrayRef<std::int64_t> lenParameters) override final {
return Fortran::lower::getFIRType(&getMLIRContext(), tc, kind,
lenParameters);
}
mlir::Type
genType(const Fortran::semantics::DerivedTypeSpec &tySpec) override final {
return Fortran::lower::translateDerivedTypeToFIRType(*this, tySpec);
}
mlir::Type genType(Fortran::common::TypeCategory tc) override final {
TODO_NOLOC("Not implemented genType TypeCategory. Needed for more complex "
"expression lowering");
}
mlir::Type genType(const Fortran::lower::pft::Variable &var) override final {
return Fortran::lower::translateVariableToFIRType(*this, var);
}

View File

@ -215,7 +215,11 @@ void Fortran::lower::CallerInterface::walkResultLengths(
dynamicType.GetCharLength())
visitor(toEvExpr(*length));
} else if (dynamicType.category() == common::TypeCategory::Derived) {
TODO(converter.getCurrentLocation(), "walkResultLengths derived type");
const Fortran::semantics::DerivedTypeSpec &derivedTypeSpec =
dynamicType.GetDerivedTypeSpec();
if (Fortran::semantics::CountLenParameters(derivedTypeSpec) > 0)
TODO(converter.getCurrentLocation(),
"function result with derived type length parameters");
}
}
@ -759,8 +763,10 @@ private:
Fortran::common::TypeCategory cat = dynamicType.category();
// DERIVED
if (cat == Fortran::common::TypeCategory::Derived) {
TODO(interface.converter.getCurrentLocation(),
"[translateDynamicType] Derived types");
if (dynamicType.IsPolymorphic())
TODO(interface.converter.getCurrentLocation(),
"[translateDynamicType] polymorphic types");
return getConverter().genType(dynamicType.GetDerivedTypeSpec());
}
// CHARACTER with compile time constant length.
if (cat == Fortran::common::TypeCategory::Character)

View File

@ -1109,10 +1109,10 @@ public:
}
ExtValue gen(const Fortran::evaluate::DataRef &dref) {
TODO(getLoc(), "gen DataRef");
return std::visit([&](const auto &x) { return gen(x); }, dref.u);
}
ExtValue genval(const Fortran::evaluate::DataRef &dref) {
TODO(getLoc(), "genval DataRef");
return std::visit([&](const auto &x) { return genval(x); }, dref.u);
}
// Helper function to turn the Component structure into a list of nested
@ -1166,10 +1166,18 @@ public:
}
ExtValue gen(const Fortran::evaluate::Component &cmpt) {
TODO(getLoc(), "gen Component");
// Components may be pointer or allocatable. In the gen() path, the mutable
// aspect is lost to simplify handling on the client side. To retain the
// mutable aspect, genMutableBoxValue should be used.
return genComponent(cmpt).match(
[&](const fir::MutableBoxValue &mutableBox) {
return fir::factory::genMutableBoxRead(builder, getLoc(), mutableBox);
},
[](auto &box) -> ExtValue { return box; });
}
ExtValue genval(const Fortran::evaluate::Component &cmpt) {
TODO(getLoc(), "genval Component");
return genLoad(gen(cmpt));
}
ExtValue genval(const Fortran::semantics::Bound &bound) {
@ -1345,7 +1353,7 @@ public:
mlir::Type genType(const Fortran::evaluate::DynamicType &dt) {
if (dt.category() != Fortran::common::TypeCategory::Derived)
return converter.genType(dt.category(), dt.kind());
TODO(getLoc(), "genType Derived Type");
return converter.genType(dt.GetDerivedTypeSpec());
}
/// Lower a function reference

View File

@ -8,6 +8,7 @@
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Mangler.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/Support/Utils.h"
#include "flang/Lower/Todo.h"
@ -16,6 +17,7 @@
#include "flang/Semantics/type.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "flang-lower-type"
@ -139,7 +141,7 @@ public:
mlir::Type baseType;
if (category == Fortran::common::TypeCategory::Derived) {
TODO(converter.getCurrentLocation(), "genExprType derived");
baseType = genDerivedType(dynamicType->GetDerivedTypeSpec());
} else {
// LOGICAL, INTEGER, REAL, COMPLEX, CHARACTER
llvm::SmallVector<Fortran::lower::LenParameterTy> params;
@ -231,8 +233,9 @@ public:
ty = genFIRType(context, tySpec->category(), kind, params);
} else if (type->IsPolymorphic()) {
TODO(loc, "genSymbolType polymorphic types");
} else if (type->AsDerived()) {
TODO(loc, "genSymbolType derived type");
} else if (const Fortran::semantics::DerivedTypeSpec *tySpec =
type->AsDerived()) {
ty = genDerivedType(*tySpec);
} else {
fir::emitFatalError(loc, "symbol's type must have a type spec");
}
@ -263,6 +266,71 @@ public:
return ty;
}
/// Does \p component has non deferred lower bounds that are not compile time
/// constant 1.
static bool componentHasNonDefaultLowerBounds(
const Fortran::semantics::Symbol &component) {
if (const auto *objDetails =
component.detailsIf<Fortran::semantics::ObjectEntityDetails>())
for (const Fortran::semantics::ShapeSpec &bounds : objDetails->shape())
if (auto lb = bounds.lbound().GetExplicit())
if (auto constant = Fortran::evaluate::ToInt64(*lb))
if (!constant || *constant != 1)
return true;
return false;
}
mlir::Type genDerivedType(const Fortran::semantics::DerivedTypeSpec &tySpec) {
std::vector<std::pair<std::string, mlir::Type>> ps;
std::vector<std::pair<std::string, mlir::Type>> cs;
const Fortran::semantics::Symbol &typeSymbol = tySpec.typeSymbol();
if (mlir::Type ty = getTypeIfDerivedAlreadyInConstruction(typeSymbol))
return ty;
auto rec = fir::RecordType::get(context,
Fortran::lower::mangle::mangleName(tySpec));
// Maintain the stack of types for recursive references.
derivedTypeInConstruction.emplace_back(typeSymbol, rec);
// Gather the record type fields.
// (1) The data components.
for (const auto &field :
Fortran::semantics::OrderedComponentIterator(tySpec)) {
// Lowering is assuming non deferred component lower bounds are always 1.
// Catch any situations where this is not true for now.
if (componentHasNonDefaultLowerBounds(field))
TODO(converter.genLocation(field.name()),
"lowering derived type components with non default lower bounds");
if (IsProcName(field))
TODO(converter.genLocation(field.name()), "procedure components");
mlir::Type ty = genSymbolType(field);
// Do not add the parent component (component of the parents are
// added and should be sufficient, the parent component would
// duplicate the fields).
if (field.test(Fortran::semantics::Symbol::Flag::ParentComp))
continue;
cs.emplace_back(field.name().ToString(), ty);
}
// (2) The LEN type parameters.
for (const auto &param :
Fortran::semantics::OrderParameterDeclarations(typeSymbol))
if (param->get<Fortran::semantics::TypeParamDetails>().attr() ==
Fortran::common::TypeParamAttr::Len)
ps.emplace_back(param->name().ToString(), genSymbolType(*param));
rec.finalize(ps, cs);
popDerivedTypeInConstruction();
if (!ps.empty()) {
// This type is a PDT (parametric derived type). Create the functions to
// use for allocation, dereferencing, and address arithmetic here.
TODO(converter.genLocation(typeSymbol.name()),
"parametrized derived types lowering");
}
LLVM_DEBUG(llvm::dbgs() << "derived type: " << rec << '\n');
return rec;
}
// To get the character length from a symbol, make an fold a designator for
// the symbol to cover the case where the symbol is an assumed length named
// constant and its length comes from its init expression length.
@ -326,7 +394,27 @@ public:
return genSymbolType(var.getSymbol(), var.isHeapAlloc(), var.isPointer());
}
private:
/// Derived type can be recursive. That is, pointer components of a derived
/// type `t` have type `t`. This helper returns `t` if it is already being
/// lowered to avoid infinite loops.
mlir::Type getTypeIfDerivedAlreadyInConstruction(
const Fortran::lower::SymbolRef derivedSym) const {
for (const auto &[sym, type] : derivedTypeInConstruction)
if (sym == derivedSym)
return type;
return {};
}
void popDerivedTypeInConstruction() {
assert(!derivedTypeInConstruction.empty());
derivedTypeInConstruction.pop_back();
}
/// Stack derived type being processed to avoid infinite loops in case of
/// recursive derived types. The depth of derived types is expected to be
/// shallow (<10), so a SmallVector is sufficient.
llvm::SmallVector<std::pair<const Fortran::lower::SymbolRef, mlir::Type>>
derivedTypeInConstruction;
Fortran::lower::AbstractConverter &converter;
mlir::MLIRContext *context;
};
@ -340,6 +428,12 @@ mlir::Type Fortran::lower::getFIRType(mlir::MLIRContext *context,
return genFIRType(context, tc, kind, params);
}
mlir::Type Fortran::lower::translateDerivedTypeToFIRType(
Fortran::lower::AbstractConverter &converter,
const Fortran::semantics::DerivedTypeSpec &tySpec) {
return TypeBuilder{converter}.genDerivedType(tySpec);
}
mlir::Type Fortran::lower::translateSomeExprToFIRType(
Fortran::lower::AbstractConverter &converter, const SomeExpr &expr) {
return TypeBuilder{converter}.genExprType(expr);

View File

@ -0,0 +1,195 @@
! Test basic parts of derived type entities lowering
! RUN: bbc -emit-fir %s -o - | FileCheck %s
! Note: only testing non parametrized derived type here.
module d
type r
real :: x
end type
type r2
real :: x_array(10, 20)
end type
type c
character(10) :: ch
end type
type c2
character(10) :: ch_array(20, 30)
end type
contains
! -----------------------------------------------------------------------------
! Test simple derived type symbol lowering
! -----------------------------------------------------------------------------
! CHECK-LABEL: func @_QMdPderived_dummy(
! CHECK-SAME: %{{.*}}: !fir.ref<!fir.type<_QMdTr{x:f32}>>{{.*}}, %{{.*}}: !fir.ref<!fir.type<_QMdTc2{ch_array:!fir.array<20x30x!fir.char<1,10>>}>>{{.*}}) {
subroutine derived_dummy(some_r, some_c2)
type(r) :: some_r
type(c2) :: some_c2
end subroutine
! CHECK-LABEL: func @_QMdPlocal_derived(
subroutine local_derived()
! CHECK-DAG: fir.alloca !fir.type<_QMdTc2{ch_array:!fir.array<20x30x!fir.char<1,10>>}>
! CHECK-DAG: fir.alloca !fir.type<_QMdTr{x:f32}>
type(r) :: some_r
type(c2) :: some_c2
end subroutine
! CHECK-LABEL: func @_QMdPsaved_derived(
subroutine saved_derived()
! CHECK-DAG: fir.address_of(@_QMdFsaved_derivedEsome_c2) : !fir.ref<!fir.type<_QMdTc2{ch_array:!fir.array<20x30x!fir.char<1,10>>}>>
! CHECK-DAG: fir.address_of(@_QMdFsaved_derivedEsome_r) : !fir.ref<!fir.type<_QMdTr{x:f32}>>
type(r), save :: some_r
type(c2), save :: some_c2
call use_symbols(some_r, some_c2)
end subroutine
! -----------------------------------------------------------------------------
! Test simple derived type references
! -----------------------------------------------------------------------------
! CHECK-LABEL: func @_QMdPscalar_numeric_ref(
subroutine scalar_numeric_ref()
! CHECK: %[[alloc:.*]] = fir.alloca !fir.type<_QMdTr{x:f32}>
type(r) :: some_r
! CHECK: %[[field:.*]] = fir.field_index x, !fir.type<_QMdTr{x:f32}>
! CHECK: fir.coordinate_of %[[alloc]], %[[field]] : (!fir.ref<!fir.type<_QMdTr{x:f32}>>, !fir.field) -> !fir.ref<f32>
call real_bar(some_r%x)
end subroutine
! CHECK-LABEL: func @_QMdPscalar_character_ref(
subroutine scalar_character_ref()
! CHECK: %[[alloc:.*]] = fir.alloca !fir.type<_QMdTc{ch:!fir.char<1,10>}>
type(c) :: some_c
! CHECK: %[[field:.*]] = fir.field_index ch, !fir.type<_QMdTc{ch:!fir.char<1,10>}>
! CHECK: %[[coor:.*]] = fir.coordinate_of %[[alloc]], %[[field]] : (!fir.ref<!fir.type<_QMdTc{ch:!fir.char<1,10>}>>, !fir.field) -> !fir.ref<!fir.char<1,10>>
! CHECK-DAG: %[[c10:.*]] = arith.constant 10 : index
! CHECK-DAG: %[[conv:.*]] = fir.convert %[[coor]] : (!fir.ref<!fir.char<1,10>>) -> !fir.ref<!fir.char<1,?>>
! CHECK: fir.emboxchar %[[conv]], %c10 : (!fir.ref<!fir.char<1,?>>, index) -> !fir.boxchar<1>
call char_bar(some_c%ch)
end subroutine
! FIXME: coordinate of generated for derived%array_comp(i) are not zero based as they
! should be.
! CHECK-LABEL: func @_QMdParray_comp_elt_ref(
subroutine array_comp_elt_ref()
type(r2) :: some_r2
! CHECK: %[[alloc:.*]] = fir.alloca !fir.type<_QMdTr2{x_array:!fir.array<10x20xf32>}>
! CHECK: %[[field:.*]] = fir.field_index x_array, !fir.type<_QMdTr2{x_array:!fir.array<10x20xf32>}>
! CHECK: %[[coor:.*]] = fir.coordinate_of %[[alloc]], %[[field]] : (!fir.ref<!fir.type<_QMdTr2{x_array:!fir.array<10x20xf32>}>>, !fir.field) -> !fir.ref<!fir.array<10x20xf32>>
! CHECK-DAG: %[[index1:.*]] = arith.subi %c5{{.*}}, %c1{{.*}} : i64
! CHECK-DAG: %[[index2:.*]] = arith.subi %c6{{.*}}, %c1{{.*}} : i64
! CHECK: fir.coordinate_of %[[coor]], %[[index1]], %[[index2]] : (!fir.ref<!fir.array<10x20xf32>>, i64, i64) -> !fir.ref<f32>
call real_bar(some_r2%x_array(5, 6))
end subroutine
! CHECK-LABEL: func @_QMdPchar_array_comp_elt_ref(
subroutine char_array_comp_elt_ref()
type(c2) :: some_c2
! CHECK: %[[coor:.*]] = fir.coordinate_of %{{.*}}, %{{.*}} : (!fir.ref<!fir.type<_QMdTc2{ch_array:!fir.array<20x30x!fir.char<1,10>>}>>, !fir.field) -> !fir.ref<!fir.array<20x30x!fir.char<1,10>>>
! CHECK-DAG: %[[index1:.*]] = arith.subi %c5{{.*}}, %c1{{.*}} : i64
! CHECK-DAG: %[[index2:.*]] = arith.subi %c6{{.*}}, %c1{{.*}} : i64
! CHECK: fir.coordinate_of %[[coor]], %[[index1]], %[[index2]] : (!fir.ref<!fir.array<20x30x!fir.char<1,10>>>, i64, i64) -> !fir.ref<!fir.char<1,10>>
! CHECK: fir.emboxchar %{{.*}}, %c10 : (!fir.ref<!fir.char<1,?>>, index) -> !fir.boxchar<1>
call char_bar(some_c2%ch_array(5, 6))
end subroutine
! CHECK: @_QMdParray_elt_comp_ref
subroutine array_elt_comp_ref()
type(r) :: some_r_array(100)
! CHECK: %[[alloca:.*]] = fir.alloca !fir.array<100x!fir.type<_QMdTr{x:f32}>>
! CHECK: %[[index:.*]] = arith.subi %c5{{.*}}, %c1{{.*}} : i64
! CHECK: %[[elt:.*]] = fir.coordinate_of %[[alloca]], %[[index]] : (!fir.ref<!fir.array<100x!fir.type<_QMdTr{x:f32}>>>, i64) -> !fir.ref<!fir.type<_QMdTr{x:f32}>>
! CHECK: %[[field:.*]] = fir.field_index x, !fir.type<_QMdTr{x:f32}>
! CHECK: fir.coordinate_of %[[elt]], %[[field]] : (!fir.ref<!fir.type<_QMdTr{x:f32}>>, !fir.field) -> !fir.ref<f32>
call real_bar(some_r_array(5)%x)
end subroutine
! CHECK: @_QMdPchar_array_elt_comp_ref
subroutine char_array_elt_comp_ref()
type(c) :: some_c_array(100)
! CHECK: fir.coordinate_of %{{.*}}, %{{.*}} : (!fir.ref<!fir.array<100x!fir.type<_QMdTc{ch:!fir.char<1,10>}>>>, i64) -> !fir.ref<!fir.type<_QMdTc{ch:!fir.char<1,10>}>>
! CHECK: fir.coordinate_of %{{.*}}, %{{.*}} : (!fir.ref<!fir.type<_QMdTc{ch:!fir.char<1,10>}>>, !fir.field) -> !fir.ref<!fir.char<1,10>>
! CHECK: fir.emboxchar %{{.*}}, %c10{{.*}} : (!fir.ref<!fir.char<1,?>>, index) -> !fir.boxchar<1>
call char_bar(some_c_array(5)%ch)
end subroutine
! -----------------------------------------------------------------------------
! Test loading derived type components
! -----------------------------------------------------------------------------
! Most of the other tests only require lowering code to compute the address of
! components. This one requires loading a component which tests other code paths
! in lowering.
! CHECK-LABEL: func @_QMdPscalar_numeric_load(
! CHECK-SAME: %[[arg0:.*]]: !fir.ref<!fir.type<_QMdTr{x:f32}>>
real function scalar_numeric_load(some_r)
type(r) :: some_r
! CHECK: %[[field:.*]] = fir.field_index x, !fir.type<_QMdTr{x:f32}>
! CHECK: %[[coor:.*]] = fir.coordinate_of %[[arg0]], %[[field]] : (!fir.ref<!fir.type<_QMdTr{x:f32}>>, !fir.field) -> !fir.ref<f32>
! CHECK: fir.load %[[coor]]
scalar_numeric_load = some_r%x
end function
! -----------------------------------------------------------------------------
! Test returned derived types (no length parameters)
! -----------------------------------------------------------------------------
! CHECK-LABEL: func @_QMdPbar_return_derived() -> !fir.type<_QMdTr{x:f32}>
function bar_return_derived()
! CHECK: %[[res:.*]] = fir.alloca !fir.type<_QMdTr{x:f32}>
type(r) :: bar_return_derived
! CHECK: %[[resLoad:.*]] = fir.load %[[res]] : !fir.ref<!fir.type<_QMdTr{x:f32}>>
! CHECK: return %[[resLoad]] : !fir.type<_QMdTr{x:f32}>
end function
! CHECK-LABEL: func @_QMdPcall_bar_return_derived(
subroutine call_bar_return_derived()
! CHECK: %[[tmp:.*]] = fir.alloca !fir.type<_QMdTr{x:f32}>
! CHECK: %[[call:.*]] = fir.call @_QMdPbar_return_derived() : () -> !fir.type<_QMdTr{x:f32}>
! CHECK: fir.save_result %[[call]] to %[[tmp]] : !fir.type<_QMdTr{x:f32}>, !fir.ref<!fir.type<_QMdTr{x:f32}>>
! CHECK: fir.call @_QPr_bar(%[[tmp]]) : (!fir.ref<!fir.type<_QMdTr{x:f32}>>) -> ()
call r_bar(bar_return_derived())
end subroutine
end module
! -----------------------------------------------------------------------------
! Test derived type with pointer/allocatable components
! -----------------------------------------------------------------------------
module d2
type recursive_t
real :: x
type(recursive_t), pointer :: ptr
end type
contains
! CHECK-LABEL: func @_QMd2Ptest_recursive_type(
! CHECK-SAME: %{{.*}}: !fir.ref<!fir.type<_QMd2Trecursive_t{x:f32,ptr:!fir.box<!fir.ptr<!fir.type<_QMd2Trecursive_t>>>}>>{{.*}}) {
subroutine test_recursive_type(some_recursive)
type(recursive_t) :: some_recursive
end subroutine
end module
! -----------------------------------------------------------------------------
! Test global derived type symbol lowering
! -----------------------------------------------------------------------------
module data_mod
use d
type(r) :: some_r
type(c2) :: some_c2
end module
! Test globals
! CHECK-DAG: fir.global @_QMdata_modEsome_c2 : !fir.type<_QMdTc2{ch_array:!fir.array<20x30x!fir.char<1,10>>}>
! CHECK-DAG: fir.global @_QMdata_modEsome_r : !fir.type<_QMdTr{x:f32}>
! CHECK-DAG: fir.global internal @_QMdFsaved_derivedEsome_c2 : !fir.type<_QMdTc2{ch_array:!fir.array<20x30x!fir.char<1,10>>}>
! CHECK-DAG: fir.global internal @_QMdFsaved_derivedEsome_r : !fir.type<_QMdTr{x:f32}>