[flang] Fix assert on constant folding of extended types

When we define a derived type that extends another derived type, we can then
create a structure constructor that contains values for the fields of both the
child type and its parent.  The compiler's internal representation of that
value contains the name of the parent type where a component name would
normally appear.  This caused an assert during contant folding.

There are three cases for components that appear in structure constructors.
The first is the normal case of a component appearing in a structure
constructor for its type.

  The second is a component of the parent (or grandparent) type appearing in a
  structure constructor for the child type.

  The third is the parent type component, which can appear in the structure
  constructor of its child.

There are also cases where the component can be arrays.

I created the test case folding12.f90 that covers all of these cases and
modified the code to handle them.

Most of my changes were to the "Find()" method of the type
"StructureConstructor" where I added code to cover the second and third cases
described above.  To handle these cases, I needed to create a
"StructureConstructor" for the parent type component and return it.  To handle
returning a newly created "StructureConstructor", I changed the return type of
"Find()" to be "std::optional" rather than an ordinary pointer.

This change supersedes D86172.

Differential Revision: https://reviews.llvm.org/D87151
This commit is contained in:
Peter Steinfeld 2020-09-04 08:44:52 -07:00
parent 485f3f35cc
commit b34f116856
6 changed files with 245 additions and 11 deletions

View File

@ -717,7 +717,8 @@ public:
return values_.end();
}
const Expr<SomeType> *Find(const Symbol &) const; // can return null
// can return nullopt
std::optional<Expr<SomeType>> Find(const Symbol &) const;
StructureConstructor &Add(const semantics::Symbol &, Expr<SomeType> &&);
int Rank() const { return 0; }
@ -725,6 +726,7 @@ public:
llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;
private:
std::optional<Expr<SomeType>> CreateParentComponent(const Symbol &) const;
Result result_;
StructureConstructorValues values_;
};

View File

@ -217,6 +217,8 @@ private:
const semantics::DerivedTypeSpec *GetDerivedTypeSpec(const DynamicType &);
const semantics::DerivedTypeSpec *GetDerivedTypeSpec(
const std::optional<DynamicType> &);
const semantics::DerivedTypeSpec *GetParentTypeSpec(
const semantics::DerivedTypeSpec &);
std::string DerivedTypeSpecAsFortran(const semantics::DerivedTypeSpec &);

View File

@ -12,7 +12,12 @@
#include "flang/Evaluate/common.h"
#include "flang/Evaluate/tools.h"
#include "flang/Evaluate/variable.h"
#include "flang/Parser/char-block.h"
#include "flang/Parser/message.h"
#include "flang/Semantics/scope.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "flang/Semantics/type.h"
#include "llvm/Support/raw_ostream.h"
#include <string>
#include <type_traits>
@ -206,13 +211,75 @@ bool Expr<SomeType>::operator==(const Expr<SomeType> &that) const {
DynamicType StructureConstructor::GetType() const { return result_.GetType(); }
const Expr<SomeType> *StructureConstructor::Find(
std::optional<Expr<SomeType>> StructureConstructor::CreateParentComponent(
const Symbol &component) const {
if (const semantics::DerivedTypeSpec *
parentSpec{GetParentTypeSpec(derivedTypeSpec())}) {
StructureConstructor structureConstructor{*parentSpec};
if (const auto *parentDetails{
component.detailsIf<semantics::DerivedTypeDetails>()}) {
auto parentIter{parentDetails->componentNames().begin()};
for (const auto &childIter : values_) {
if (parentIter == parentDetails->componentNames().end()) {
break; // There are more components in the child
}
SymbolRef componentSymbol{childIter.first};
structureConstructor.Add(
*componentSymbol, common::Clone(childIter.second.value()));
++parentIter;
}
Constant<SomeDerived> constResult{std::move(structureConstructor)};
Expr<SomeDerived> result{std::move(constResult)};
return std::optional<Expr<SomeType>>{result};
}
}
return std::nullopt;
}
static const Symbol *GetParentComponentSymbol(const Symbol &symbol) {
if (symbol.test(Symbol::Flag::ParentComp)) {
// we have a created parent component
const auto &compObject{symbol.get<semantics::ObjectEntityDetails>()};
if (const semantics::DeclTypeSpec * compType{compObject.type()}) {
const semantics::DerivedTypeSpec &dtSpec{compType->derivedTypeSpec()};
const semantics::Symbol &compTypeSymbol{dtSpec.typeSymbol()};
return &compTypeSymbol;
}
}
if (symbol.detailsIf<semantics::DerivedTypeDetails>()) {
// we have an implicit parent type component
return &symbol;
}
return nullptr;
}
std::optional<Expr<SomeType>> StructureConstructor::Find(
const Symbol &component) const {
if (auto iter{values_.find(component)}; iter != values_.end()) {
return &iter->second.value();
} else {
return nullptr;
return iter->second.value();
}
// The component wasn't there directly, see if we're looking for the parent
// component of an extended type
if (const Symbol * typeSymbol{GetParentComponentSymbol(component)}) {
return CreateParentComponent(*typeSymbol);
}
// Look for the component in the parent type component. The parent type
// component is always the first one
if (!values_.empty()) {
const Expr<SomeType> *parentExpr{&values_.begin()->second.value()};
if (const Expr<SomeDerived> *derivedExpr{
std::get_if<Expr<SomeDerived>>(&parentExpr->u)}) {
if (const Constant<SomeDerived> *constExpr{
std::get_if<Constant<SomeDerived>>(&derivedExpr->u)}) {
if (std::optional<StructureConstructor> parentComponentValue{
constExpr->GetScalarValue()}) {
// Try to find the component in the parent structure constructor
return parentComponentValue->Find(component);
}
}
}
}
return std::nullopt;
}
StructureConstructor &StructureConstructor::Add(

View File

@ -296,8 +296,8 @@ std::optional<Constant<T>> Folder<T>::ApplyComponent(
Constant<SomeDerived> &&structures, const Symbol &component,
const std::vector<Constant<SubscriptInteger>> *subscripts) {
if (auto scalar{structures.GetScalarValue()}) {
if (auto *expr{scalar->Find(component)}) {
if (const Constant<T> *value{UnwrapConstantValue<T>(*expr)}) {
if (std::optional<Expr<SomeType>> expr{scalar->Find(component)}) {
if (const Constant<T> *value{UnwrapConstantValue<T>(expr.value())}) {
if (!subscripts) {
return std::move(*value);
} else {
@ -314,12 +314,12 @@ std::optional<Constant<T>> Folder<T>::ApplyComponent(
ConstantSubscripts at{structures.lbounds()};
do {
StructureConstructor scalar{structures.At(at)};
if (auto *expr{scalar.Find(component)}) {
if (const Constant<T> *value{UnwrapConstantValue<T>(*expr)}) {
if (std::optional<Expr<SomeType>> expr{scalar.Find(component)}) {
if (const Constant<T> *value{UnwrapConstantValue<T>(expr.value())}) {
if (!array.get()) {
// This technique ensures that character length or derived type
// information is propagated to the array constructor.
auto *typedExpr{UnwrapExpr<Expr<T>>(*expr)};
auto *typedExpr{UnwrapExpr<Expr<T>>(expr.value())};
CHECK(typedExpr);
array = std::make_unique<ArrayConstructor<T>>(*typedExpr);
}

View File

@ -207,7 +207,7 @@ static const semantics::Symbol *FindParentComponent(
return nullptr;
}
static const semantics::DerivedTypeSpec *GetParentTypeSpec(
const semantics::DerivedTypeSpec *GetParentTypeSpec(
const semantics::DerivedTypeSpec &derived) {
if (const semantics::Symbol * parent{FindParentComponent(derived)}) {
return &parent->get<semantics::ObjectEntityDetails>()

View File

@ -0,0 +1,163 @@
! RUN: %S/test_folding.sh %s %t %f18
! Test folding of structure constructors
module m1
type parent_type
integer :: parent_field
end type parent_type
type, extends(parent_type) :: child_type
integer :: child_field
end type child_type
type parent_array_type
integer, dimension(2) :: parent_field
end type parent_array_type
type, extends(parent_array_type) :: child_array_type
integer :: child_field
end type child_array_type
type(child_type), parameter :: child_const1 = child_type(10, 11)
logical, parameter :: test_child1 = child_const1%child_field == 11
logical, parameter :: test_parent = child_const1%parent_field == 10
type(child_type), parameter :: child_const2 = child_type(12, 13)
type(child_type), parameter :: array_var(2) = &
[child_type(14, 15), child_type(16, 17)]
logical, parameter :: test_array_child = array_var(2)%child_field == 17
logical, parameter :: test_array_parent = array_var(2)%parent_field == 16
type array_type
real, dimension(3) :: real_field
end type array_type
type(array_type), parameter :: array_var2 = &
array_type([(real(i*i), i = 1,3)])
logical, parameter :: test_array_var = array_var2%real_field(2) == 4.0
type(child_type), parameter, dimension(2) :: child_const3 = &
[child_type(18, 19), child_type(20, 21)]
integer, dimension(2), parameter :: int_const4 = &
child_const3(:)%parent_field
logical, parameter :: test_child2 = int_const4(1) == 18
type(child_array_type), parameter, dimension(2) :: child_const5 = &
[child_array_type([22, 23], 24), child_array_type([25, 26], 27)]
integer, dimension(2), parameter :: int_const6 = child_const5(:)%parent_field(2)
logical, parameter :: test_child3 = int_const6(1) == 23
type(child_type), parameter :: child_const7 = child_type(28, 29)
type(parent_type), parameter :: parent_const8 = child_const7%parent_type
logical, parameter :: test_child4 = parent_const8%parent_field == 28
type(child_type), parameter :: child_const9 = &
child_type(parent_type(30), 31)
integer, parameter :: int_const10 = child_const9%parent_field
logical, parameter :: test_child5 = int_const10 == 30
end module m1
module m2
type grandparent_type
real :: grandparent_field
end type grandparent_type
type, extends(grandparent_type) :: parent_type
integer :: parent_field
end type parent_type
type, extends(parent_type) :: child_type
real :: child_field
end type child_type
type(child_type), parameter :: child_const1 = child_type(10.0, 11, 12.0)
integer, parameter :: int_const2 = &
child_const1%grandparent_type%grandparent_field
logical, parameter :: test_child1 = int_const2 == 10.0
integer, parameter :: int_const3 = &
child_const1%grandparent_field
logical, parameter :: test_child2 = int_const3 == 10.0
type(child_type), parameter :: child_const4 = &
child_type(parent_type(13.0, 14), 15.0)
integer, parameter :: int_const5 = &
child_const4%grandparent_type%grandparent_field
logical, parameter :: test_child3 = int_const5 == 13.0
type(child_type), parameter :: child_const6 = &
child_type(parent_type(grandparent_type(16.0), 17), 18.0)
integer, parameter :: int_const7 = &
child_const6%grandparent_type%grandparent_field
logical, parameter :: test_child4 = int_const7 == 16.0
integer, parameter :: int_const8 = &
child_const6%grandparent_field
logical, parameter :: test_child5 = int_const8 == 16.0
end module m2
module m3
! tests that use components with default initializations and with the
! components in the structure constructors in a different order from the
! declared order
type parent_type
integer :: parent_field1
real :: parent_field2 = 20.0
logical :: parent_field3
end type parent_type
type, extends(parent_type) :: child_type
real :: child_field1
logical :: child_field2 = .false.
integer :: child_field3
end type child_type
type(child_type), parameter :: child_const1 = &
child_type( &
parent_field2 = 10.0, child_field3 = 11, &
child_field2 = .true., parent_field3 = .false., &
parent_field1 = 12, child_field1 = 13.3)
logical, parameter :: test_child1 = child_const1%child_field1 == 13.3
logical, parameter :: test_child2 = child_const1%child_field2 .eqv. .true.
logical, parameter :: test_child3 = child_const1%child_field3 == 11
logical, parameter :: test_parent1 = child_const1%parent_field1 == 12
logical, parameter :: test_parent2 = child_const1%parent_field2 == 10.0
logical, parameter :: test_parent3 = child_const1%parent_field3 .eqv. .false.
logical, parameter :: test_parent4 = &
child_const1%parent_type%parent_field1 == 12
logical, parameter :: test_parent5 = &
child_const1%parent_type%parent_field2 == 10.0
logical, parameter :: test_parent6 = &
child_const1%parent_type%parent_field3 .eqv. .false.
type(parent_type), parameter ::parent_const1 = child_const1%parent_type
logical, parameter :: test_parent7 = parent_const1%parent_field1 == 12
logical, parameter :: test_parent8 = parent_const1%parent_field2 == 10.0
logical, parameter :: test_parent9 = &
parent_const1%parent_field3 .eqv. .false.
type(child_type), parameter :: child_const2 = &
child_type( &
child_field3 = 14, parent_field3 = .true., &
parent_field1 = 15, child_field1 = 16.6)
logical, parameter :: test_child4 = child_const2%child_field1 == 16.6
logical, parameter :: test_child5 = child_const2%child_field2 .eqv. .false.
logical, parameter :: test_child6 = child_const2%child_field3 == 14
logical, parameter :: test_parent10 = child_const2%parent_field1 == 15
logical, parameter :: test_parent11 = child_const2%parent_field2 == 20.0
logical, parameter :: test_parent12 = child_const2%parent_field3 .eqv. .true.
type(child_type), parameter :: child_const3 = &
child_type(parent_type( &
parent_field2 = 17.7, parent_field3 = .false., parent_field1 = 18), &
child_field2 = .false., child_field1 = 19.9, child_field3 = 21)
logical, parameter :: test_child7 = child_const3%parent_field1 == 18
logical, parameter :: test_child8 = child_const3%parent_field2 == 17.7
logical, parameter :: test_child9 = child_const3%parent_field3 .eqv. .false.
logical, parameter :: test_child10 = child_const3%child_field1 == 19.9
logical, parameter :: test_child11 = child_const3%child_field2 .eqv. .false.
logical, parameter :: test_child12 = child_const3%child_field3 == 21
type(child_type), parameter :: child_const4 = &
child_type(parent_type( &
parent_field3 = .true., parent_field1 = 22), &
child_field1 = 23.4, child_field3 = 24)
logical, parameter :: test_child13 = child_const4%parent_field1 == 22
logical, parameter :: test_child14 = child_const4%parent_field2 == 20.0
logical, parameter :: test_child15 = child_const4%parent_field3 .eqv. .true.
logical, parameter :: test_child16 = child_const4%child_field1 == 23.4
logical, parameter :: test_child17 = child_const4%child_field2 .eqv. .false.
logical, parameter :: test_child18 = child_const4%child_field3 == 24
end module m3