[flang] more conversions to Traverse

Original-commit: flang-compiler/f18@e8668e2368
Reviewed-on: https://github.com/flang-compiler/f18/pull/755
Tree-same-pre-rewrite: false
This commit is contained in:
peter klausler 2019-09-19 14:56:12 -07:00
parent 48fd773a19
commit f07d6bc6ba
5 changed files with 126 additions and 113 deletions

View File

@ -25,6 +25,7 @@
#include "intrinsics-library-templates.h"
#include "shape.h"
#include "tools.h"
#include "traverse.h"
#include "type.h"
#include "../common/indirection.h"
#include "../common/template.h"
@ -1930,15 +1931,17 @@ auto ApplyElementwise(
// Predicate: is a scalar expression suitable for naive scalar expansion
// in the flattening of an array expression?
// TODO: capture such scalar expansions in temporaries, flatten everything
struct UnexpandabilityFindingVisitor : public virtual VisitorBase<bool> {
using Result = bool;
explicit UnexpandabilityFindingVisitor(int) { result() = false; }
template<typename T> void Handle(FunctionRef<T> &) { Return(true); }
template<typename T> void Handle(CoarrayRef &) { Return(true); }
struct UnexpandabilityFindingVisitor
: public AnyTraverse<UnexpandabilityFindingVisitor> {
using Base = AnyTraverse<UnexpandabilityFindingVisitor>;
using Base::operator();
UnexpandabilityFindingVisitor() : Base{*this} {}
template<typename T> bool operator()(const FunctionRef<T> &) { return true; }
bool operator()(const CoarrayRef &) { return true; }
};
template<typename T> bool IsExpandableScalar(const Expr<T> &expr) {
return !Visitor<UnexpandabilityFindingVisitor>{0}.Traverse(expr);
return !UnexpandabilityFindingVisitor{}(expr);
}
template<typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>

View File

@ -15,7 +15,6 @@
#include "shape.h"
#include "fold.h"
#include "tools.h"
#include "traverse.h"
#include "type.h"
#include "../common/idioms.h"
#include "../common/template.h"
@ -347,65 +346,63 @@ Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) {
}
}
void GetShapeVisitor::Handle(const Symbol &symbol) {
std::visit(
auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result {
return std::visit(
common::visitors{
[&](const semantics::ObjectEntityDetails &) {
Handle(NamedEntity{symbol});
return (*this)(NamedEntity{symbol});
},
[&](const semantics::AssocEntityDetails &assoc) {
Nested(assoc.expr());
return (*this)(assoc.expr());
},
[&](const semantics::SubprogramDetails &subp) {
if (subp.isFunction()) {
Handle(subp.result());
return (*this)(subp.result());
} else {
Return();
return Result{};
}
},
[&](const semantics::ProcBindingDetails &binding) {
Handle(binding.symbol());
return (*this)(binding.symbol());
},
[&](const semantics::UseDetails &use) {
return (*this)(use.symbol());
},
[&](const semantics::UseDetails &use) { Handle(use.symbol()); },
[&](const semantics::HostAssocDetails &assoc) {
Handle(assoc.symbol());
return (*this)(assoc.symbol());
},
[&](const auto &) { Return(); },
[&](const auto &) { return Result{}; },
},
symbol.details());
}
void GetShapeVisitor::Handle(const Component &component) {
auto GetShapeHelper::operator()(const Component &component) const -> Result {
if (component.GetLastSymbol().Rank() > 0) {
Handle(NamedEntity{Component{component}});
return (*this)(NamedEntity{Component{component}});
} else {
Nested(component.base());
return (*this)(component.base());
}
}
void GetShapeVisitor::Handle(const NamedEntity &base) {
auto GetShapeHelper::operator()(const NamedEntity &base) const -> Result {
const Symbol &symbol{base.GetLastSymbol()};
if (const auto *object{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
if (IsImpliedShape(symbol)) {
Nested(object->init());
return (*this)(object->init());
} else {
Shape result;
int n{object->shape().Rank()};
for (int dimension{0}; dimension < n; ++dimension) {
result.emplace_back(GetExtent(context_, base, dimension));
}
Return(std::move(result));
return std::move(result);
}
} else {
Handle(symbol);
return (*this)(symbol);
}
}
void GetShapeVisitor::Handle(const Substring &substring) {
Nested(substring.parent());
}
void GetShapeVisitor::Handle(const ArrayRef &arrayRef) {
auto GetShapeHelper::operator()(const ArrayRef &arrayRef) const -> Result {
Shape shape;
int dimension{0};
for (const Subscript &ss : arrayRef.subscript()) {
@ -415,13 +412,13 @@ void GetShapeVisitor::Handle(const ArrayRef &arrayRef) {
++dimension;
}
if (shape.empty()) {
Nested(arrayRef.base());
return (*this)(arrayRef.base());
} else {
Return(std::move(shape));
return std::move(shape);
}
}
void GetShapeVisitor::Handle(const CoarrayRef &coarrayRef) {
auto GetShapeHelper::operator()(const CoarrayRef &coarrayRef) const -> Result {
Shape shape;
NamedEntity base{coarrayRef.GetBase()};
int dimension{0};
@ -432,45 +429,48 @@ void GetShapeVisitor::Handle(const CoarrayRef &coarrayRef) {
++dimension;
}
if (shape.empty()) {
Nested(base);
return (*this)(base);
} else {
Return(std::move(shape));
return std::move(shape);
}
}
void GetShapeVisitor::Handle(const ProcedureRef &call) {
auto GetShapeHelper::operator()(const Substring &substring) const -> Result {
return (*this)(substring.parent());
}
auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
if (call.Rank() == 0) {
Scalar();
return Scalar();
} else if (call.IsElemental()) {
for (const auto &arg : call.arguments()) {
if (arg.has_value() && arg->Rank() > 0) {
Nested(*arg);
return;
return (*this)(*arg);
}
}
Scalar();
return Scalar();
} else if (const Symbol * symbol{call.proc().GetSymbol()}) {
Handle(*symbol);
return (*this)(*symbol);
} else if (const auto *intrinsic{
std::get_if<SpecificIntrinsic>(&call.proc().u)}) {
if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
intrinsic->name == "ubound") {
const auto *expr{call.arguments().front().value().UnwrapExpr()};
CHECK(expr != nullptr);
Return(Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}});
return Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}};
} else if (intrinsic->name == "reshape") {
if (call.arguments().size() >= 2 && call.arguments().at(1).has_value()) {
// SHAPE(RESHAPE(array,shape)) -> shape
const auto *shapeExpr{call.arguments().at(1).value().UnwrapExpr()};
CHECK(shapeExpr != nullptr);
Expr<SomeInteger> shape{std::get<Expr<SomeInteger>>(shapeExpr->u)};
Return(AsShape(context_, ConvertToType<ExtentType>(std::move(shape))));
return AsShape(context_, ConvertToType<ExtentType>(std::move(shape)));
}
} else {
// TODO: shapes of other non-elemental intrinsic results
}
}
Return();
return std::nullopt;
}
bool CheckConformance(parser::ContextualMessages &messages, const Shape &left,
@ -500,5 +500,4 @@ bool CheckConformance(parser::ContextualMessages &messages, const Shape &left,
}
return true;
}
}

View File

@ -20,7 +20,7 @@
#include "expression.h"
#include "tools.h"
#include "traversal.h"
#include "traverse.h"
#include "type.h"
#include "variable.h"
#include "../common/indirection.h"
@ -86,56 +86,57 @@ MaybeExtentExpr GetSize(Shape &&);
// Utility predicate: does an expression reference any implied DO index?
bool ContainsAnyImpliedDoIndex(const ExtentExpr &);
// Compilation-time shape conformance checking, when corresponding extents
// are known.
bool CheckConformance(parser::ContextualMessages &, const Shape &,
const Shape &, const char * = "left operand",
const char * = "right operand");
class GetShapeVisitor : public virtual VisitorBase<std::optional<Shape>> {
// GetShape()
class GetShapeHelper
: public AnyTraverse<GetShapeHelper, std::optional<Shape>> {
public:
using Result = std::optional<Shape>;
explicit GetShapeVisitor(FoldingContext &c) : context_{c} {}
using Base = AnyTraverse<GetShapeHelper, Result>;
using Base::operator();
GetShapeHelper(FoldingContext &c) : Base{*this}, context_{c} {}
template<typename T> void Handle(const Constant<T> &c) {
Return(AsShape(c.SHAPE()));
Result operator()(const ImpliedDoIndex &) const { return Scalar(); }
Result operator()(const DescriptorInquiry &) const { return Scalar(); }
template<int KIND> Result operator()(const TypeParamInquiry<KIND> &) const {
return Scalar();
}
void Handle(const Symbol &);
void Handle(const Component &);
void Handle(const NamedEntity &);
void Handle(const StaticDataObject::Pointer &) { Scalar(); }
void Handle(const ArrayRef &);
void Handle(const CoarrayRef &);
void Handle(const Substring &);
void Handle(const ProcedureRef &);
void Handle(const StructureConstructor &) { Scalar(); }
template<typename T> void Handle(const ArrayConstructor<T> &aconst) {
Return(Shape{GetArrayConstructorExtent(aconst)});
Result operator()(const BOZLiteralConstant &) const { return Scalar(); }
Result operator()(const StaticDataObject::Pointer &) const {
return Scalar();
}
Result operator()(const StructureConstructor &) const { return Scalar(); }
template<typename T> Result operator()(const Constant<T> &c) const {
return AsShape(c.SHAPE());
}
Result operator()(const Symbol &) const;
Result operator()(const Component &) const;
Result operator()(const NamedEntity &) const;
Result operator()(const ArrayRef &) const;
Result operator()(const CoarrayRef &) const;
Result operator()(const Substring &) const;
Result operator()(const ProcedureRef &) const;
template<typename T>
Result operator()(const ArrayConstructor<T> &aconst) const {
return Shape{GetArrayConstructorExtent(aconst)};
}
void Handle(const ImpliedDoIndex &) { Scalar(); }
void Handle(const DescriptorInquiry &) { Scalar(); }
template<int KIND> void Handle(const TypeParamInquiry<KIND> &) { Scalar(); }
void Handle(const BOZLiteralConstant &) { Scalar(); }
void Handle(const NullPointer &) { Return(); }
template<typename D, typename R, typename LO, typename RO>
void Handle(const Operation<D, R, LO, RO> &operation) {
Result operator()(const Operation<D, R, LO, RO> &operation) const {
if (operation.right().Rank() > 0) {
Nested(operation.right());
(*this)(operation.right());
} else {
Nested(operation.left());
(*this)(operation.left());
}
}
private:
void Scalar() { Return(Shape{}); }
template<typename A> void Nested(const A &x) {
Return(GetShape(context_, x));
}
static Result Scalar() { return Shape{}; }
template<typename T>
MaybeExtentExpr GetArrayConstructorValueExtent(
const ArrayConstructorValue<T> &value) {
const ArrayConstructorValue<T> &value) const {
return std::visit(
common::visitors{
[&](const Expr<T> &x) -> MaybeExtentExpr {
@ -165,7 +166,7 @@ private:
template<typename T>
MaybeExtentExpr GetArrayConstructorExtent(
const ArrayConstructorValues<T> &values) {
const ArrayConstructorValues<T> &values) const {
ExtentExpr result{0};
for (const auto &value : values) {
if (MaybeExtentExpr n{GetArrayConstructorValueExtent(value)}) {
@ -182,7 +183,14 @@ private:
template<typename A>
std::optional<Shape> GetShape(FoldingContext &context, const A &x) {
return Visitor<GetShapeVisitor>{context}.Traverse(x);
return GetShapeHelper{context}(x);
}
// Compilation-time shape conformance checking, when corresponding extents
// are known.
bool CheckConformance(parser::ContextualMessages &, const Shape &,
const Shape &, const char * = "left operand",
const char * = "right operand");
}
#endif // FORTRAN_EVALUATE_SHAPE_H_

View File

@ -25,12 +25,10 @@ using namespace Fortran::parser::literals;
namespace Fortran::evaluate {
// IsVariable()
void IsVariableVisitor::Handle(const ProcedureDesignator &x) {
if (const semantics::Symbol * symbol{x.GetSymbol()}) {
Return(symbol->attrs().test(semantics::Attr::POINTER));
} else {
Return(false);
}
auto IsVariableHelper::operator()(const ProcedureDesignator &x) const
-> Result {
const semantics::Symbol *symbol{x.GetSymbol()};
return symbol && symbol->attrs().test(semantics::Attr::POINTER);
}
// Conversions of complex component expressions to REAL.

View File

@ -18,7 +18,9 @@
#include "constant.h"
#include "expression.h"
#include "traversal.h"
#include "traverse.h"
#include "../common/idioms.h"
#include "../common/template.h"
#include "../common/unwrap.h"
#include "../parser/message.h"
#include "../semantics/attr.h"
@ -26,6 +28,7 @@
#include <array>
#include <optional>
#include <set>
#include <type_traits>
#include <utility>
namespace Fortran::evaluate {
@ -61,36 +64,38 @@ std::optional<Variable<A>> AsVariable(const std::optional<Expr<A>> &expr) {
// operation. Be advised: a call to a function that returns an object
// pointer is a "variable" in Fortran (it can be the left-hand side of
// an assignment).
struct IsVariableVisitor : public virtual VisitorBase<std::optional<bool>> {
// std::optional<> is used because it is default-constructible.
using Result = std::optional<bool>;
explicit IsVariableVisitor(std::nullptr_t) {}
void Handle(const StaticDataObject &) { Return(false); }
void Handle(const Symbol &) { Return(true); }
void Pre(const Component &) { Return(true); }
void Pre(const ArrayRef &) { Return(true); }
void Pre(const CoarrayRef &) { Return(true); }
void Pre(const ComplexPart &) { Return(true); }
void Handle(const ProcedureDesignator &);
template<TypeCategory CAT, int KIND>
void Pre(const Expr<Type<CAT, KIND>> &x) {
if (!std::holds_alternative<Designator<Type<CAT, KIND>>>(x.u) &&
!std::holds_alternative<FunctionRef<Type<CAT, KIND>>>(x.u)) {
Return(false);
struct IsVariableHelper
: public AnyTraverse<IsVariableHelper, std::optional<bool>> {
using Result = std::optional<bool>; // effectively tri-state
using Base = AnyTraverse<IsVariableHelper, Result>;
IsVariableHelper() : Base{*this} {}
using Base::operator();
Result operator()(const StaticDataObject &) const { return false; }
Result operator()(const Symbol &) const { return true; }
Result operator()(const Component &) const { return true; }
Result operator()(const ArrayRef &) const { return true; }
Result operator()(const CoarrayRef &) const { return true; }
Result operator()(const ComplexPart &) const { return true; }
Result operator()(const ProcedureDesignator &) const;
template<typename T> Result operator()(const Expr<T> &x) const {
if constexpr (common::HasMember<T, AllIntrinsicTypes> ||
std::is_same_v<T, SomeDerived>) {
// Expression with a specific type
if (std::holds_alternative<Designator<T>>(x.u) ||
std::holds_alternative<FunctionRef<T>>(x.u)) {
if (auto known{(*this)(x.u)}) {
return known;
}
}
return false;
} else {
return (*this)(x.u);
}
}
void Pre(const Expr<SomeDerived> &x) {
if (!std::holds_alternative<Designator<SomeDerived>>(x.u) &&
!std::holds_alternative<FunctionRef<SomeDerived>>(x.u)) {
Return(false);
}
}
template<typename A> void Post(const A &) { Return(false); }
};
template<typename A> bool IsVariable(const A &x) {
Visitor<IsVariableVisitor> visitor{nullptr};
if (auto optional{visitor.Traverse(x)}) {
if (std::optional<bool> optional{IsVariableHelper{}(x)}) {
return *optional;
} else {
return false;