[flang] Fold CSHIFT

Implement folding of the transformational intrinsic function
CSHIFT for all types.

Differential Revision: https://reviews.llvm.org/D108931
This commit is contained in:
peter klausler 2021-07-07 10:52:09 -07:00
parent db9de22f2b
commit 0bbb2d0036
9 changed files with 123 additions and 20 deletions

View File

@ -992,6 +992,23 @@ private:
std::optional<ConstantSubscripts> lbounds_; std::optional<ConstantSubscripts> lbounds_;
}; };
// Given a collection of element values, package them as a Constant.
// If the type is Character or a derived type, take the length or type
// (resp.) from a another Constant.
template <typename T>
Constant<T> PackageConstant(std::vector<Scalar<T>> &&elements,
const Constant<T> &reference, const ConstantSubscripts &shape) {
if constexpr (T::category == TypeCategory::Character) {
return Constant<T>{
reference.LEN(), std::move(elements), ConstantSubscripts{shape}};
} else if constexpr (T::category == TypeCategory::Derived) {
return Constant<T>{reference.GetType().GetDerivedTypeSpec(),
std::move(elements), ConstantSubscripts{shape}};
} else {
return Constant<T>{std::move(elements), ConstantSubscripts{shape}};
}
}
} // namespace Fortran::evaluate } // namespace Fortran::evaluate
namespace Fortran::semantics { namespace Fortran::semantics {

View File

@ -102,8 +102,7 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
CharacterUtils<KIND>::TRIM(std::get<Scalar<T>>(*scalar))}}; CharacterUtils<KIND>::TRIM(std::get<Scalar<T>>(*scalar))}};
} }
} }
// TODO: cshift, eoshift, maxloc, minloc, pack, spread, transfer, // TODO: findloc, maxloc, minloc, transfer
// transpose, unpack
return Expr<T>{std::move(funcRef)}; return Expr<T>{std::move(funcRef)};
} }

View File

@ -60,8 +60,7 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
} else if (name == "sum") { } else if (name == "sum") {
return FoldSum<T>(context, std::move(funcRef)); return FoldSum<T>(context, std::move(funcRef));
} }
// TODO: cshift, dot_product, eoshift, matmul, pack, spread, transfer, // TODO: dot_product, matmul, transfer
// transpose, unpack
return Expr<T>{std::move(funcRef)}; return Expr<T>{std::move(funcRef)};
} }

View File

@ -60,7 +60,9 @@ public:
std::optional<Constant<T>> Folding(ArrayRef &); std::optional<Constant<T>> Folding(ArrayRef &);
Expr<T> Folding(Designator<T> &&); Expr<T> Folding(Designator<T> &&);
Constant<T> *Folding(std::optional<ActualArgument> &); Constant<T> *Folding(std::optional<ActualArgument> &);
Expr<T> Reshape(FunctionRef<T> &&);
Expr<T> CSHIFT(FunctionRef<T> &&);
Expr<T> RESHAPE(FunctionRef<T> &&);
private: private:
FoldingContext &context_; FoldingContext &context_;
@ -546,7 +548,78 @@ template <typename T> Expr<T> MakeInvalidIntrinsic(FunctionRef<T> &&funcRef) {
ActualArguments{std::move(funcRef.arguments())}}}; ActualArguments{std::move(funcRef.arguments())}}};
} }
template <typename T> Expr<T> Folder<T>::Reshape(FunctionRef<T> &&funcRef) { template <typename T> Expr<T> Folder<T>::CSHIFT(FunctionRef<T> &&funcRef) {
auto args{funcRef.arguments()};
CHECK(args.size() == 3);
const auto *array{UnwrapConstantValue<T>(args[0])};
const auto *shiftExpr{UnwrapExpr<Expr<SomeInteger>>(args[1])};
auto dim{GetInt64ArgOr(args[2], 1)};
if (!array || !shiftExpr || !dim) {
return Expr<T>{std::move(funcRef)};
}
auto convertedShift{Fold(context_,
ConvertToType<SubscriptInteger>(Expr<SomeInteger>{*shiftExpr}))};
const auto *shift{UnwrapConstantValue<SubscriptInteger>(convertedShift)};
if (!shift) {
return Expr<T>{std::move(funcRef)};
}
// Arguments are constant
if (*dim < 1 || *dim > array->Rank()) {
context_.messages().Say("Invalid 'dim=' argument (%jd) in CSHIFT"_err_en_US,
static_cast<std::intmax_t>(*dim));
} else if (shift->Rank() > 0 && shift->Rank() != array->Rank() - 1) {
// message already emitted from intrinsic look-up
} else {
int rank{array->Rank()};
int zbDim{static_cast<int>(*dim) - 1};
bool ok{true};
if (shift->Rank() > 0) {
int k{0};
for (int j{0}; j < rank; ++j) {
if (j != zbDim) {
if (array->shape()[j] != shift->shape()[k]) {
context_.messages().Say(
"Invalid 'shift=' argument in CSHIFT; extent on dimension %d is %jd but must be %jd"_err_en_US,
k + 1, static_cast<std::intmax_t>(shift->shape()[k]),
static_cast<std::intmax_t>(array->shape()[j]));
ok = false;
}
++k;
}
}
}
if (ok) {
std::vector<Scalar<T>> resultElements;
ConstantSubscripts arrayAt{array->lbounds()};
ConstantSubscript dimLB{arrayAt[zbDim]};
ConstantSubscript dimExtent{array->shape()[zbDim]};
ConstantSubscripts shiftAt{shift->lbounds()};
for (auto n{GetSize(array->shape())}; n > 0; n -= dimExtent) {
ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()};
ConstantSubscript zbDimIndex{shiftCount % dimExtent};
if (zbDimIndex < 0) {
zbDimIndex += dimExtent;
}
for (ConstantSubscript j{0}; j < dimExtent; ++j) {
arrayAt[zbDim] = dimLB + zbDimIndex;
resultElements.push_back(array->At(arrayAt));
if (++zbDimIndex == dimExtent) {
zbDimIndex = 0;
}
}
arrayAt[zbDim] = dimLB + dimExtent - 1;
array->IncrementSubscripts(arrayAt);
shift->IncrementSubscripts(shiftAt);
}
return Expr<T>{PackageConstant<T>(
std::move(resultElements), *array, array->shape())};
}
}
// Invalid, prevent re-folding
return MakeInvalidIntrinsic(std::move(funcRef));
}
template <typename T> Expr<T> Folder<T>::RESHAPE(FunctionRef<T> &&funcRef) {
auto args{funcRef.arguments()}; auto args{funcRef.arguments()};
CHECK(args.size() == 4); CHECK(args.size() == 4);
const auto *source{UnwrapConstantValue<T>(args[0])}; const auto *source{UnwrapConstantValue<T>(args[0])};
@ -679,10 +752,13 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
} }
if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) { if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) {
const std::string name{intrinsic->name}; const std::string name{intrinsic->name};
if (name == "reshape") { if (name == "cshift") {
return Folder<T>{context}.Reshape(std::move(funcRef)); return Folder<T>{context}.CSHIFT(std::move(funcRef));
} else if (name == "reshape") {
return Folder<T>{context}.RESHAPE(std::move(funcRef));
} }
// TODO: other type independent transformationals // TODO: eoshift, pack, spread, unpack, transpose
// TODO: extends_type_of, same_type_as
if constexpr (!std::is_same_v<T, SomeDerived>) { if constexpr (!std::is_same_v<T, SomeDerived>) {
return FoldIntrinsicFunction(context, std::move(funcRef)); return FoldIntrinsicFunction(context, std::move(funcRef));
} }

View File

@ -689,10 +689,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
} else if (name == "ubound") { } else if (name == "ubound") {
return UBOUND(context, std::move(funcRef)); return UBOUND(context, std::move(funcRef));
} }
// TODO: // TODO: count(w/ dim), dot_product, findloc, ibits, image_status, ishftc,
// cshift, dot_product, eoshift, findloc, ibits, image_status, ishftc, // matmul, maxloc, minloc, sign, transfer
// matmul, maxloc, minloc, not, pack, sign, spread, transfer, transpose,
// unpack
return Expr<T>{std::move(funcRef)}; return Expr<T>{std::move(funcRef)};
} }

View File

@ -125,10 +125,9 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
name == "__builtin_ieee_support_underflow_control") { name == "__builtin_ieee_support_underflow_control") {
return Expr<T>{true}; return Expr<T>{true};
} }
// TODO: btest, cshift, dot_product, eoshift, is_iostat_end, // TODO: btest, dot_product, eoshift, is_iostat_end,
// is_iostat_eor, lge, lgt, lle, llt, logical, matmul, out_of_range, // is_iostat_eor, lge, lgt, lle, llt, logical, matmul, out_of_range,
// pack, parity, spread, transfer, transpose, unpack, extends_type_of, // parity, transfer
// same_type_as
return Expr<T>{std::move(funcRef)}; return Expr<T>{std::move(funcRef)};
} }

View File

@ -135,9 +135,9 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
} else if (name == "tiny") { } else if (name == "tiny") {
return Expr<T>{Scalar<T>::TINY()}; return Expr<T>{Scalar<T>::TINY()};
} }
// TODO: cshift, dim, dot_product, eoshift, fraction, matmul, // TODO: dim, dot_product, fraction, matmul,
// maxloc, minloc, modulo, nearest, norm2, pack, rrspacing, scale, // maxloc, minloc, modulo, nearest, norm2, rrspacing, scale,
// set_exponent, spacing, spread, transfer, transpose, unpack, // set_exponent, spacing, transfer,
// bessel_jn (transformational) and bessel_yn (transformational) // bessel_jn (transformational) and bessel_yn (transformational)
return Expr<T>{std::move(funcRef)}; return Expr<T>{std::move(funcRef)};
} }

View File

@ -20,4 +20,3 @@ character(*), parameter :: zero_sized(*) = input(2:1:1) // 'abcde'
logical, parameter :: test_zero_sized = len(zero_sized).eq.6 logical, parameter :: test_zero_sized = len(zero_sized).eq.6
end end

View File

@ -0,0 +1,16 @@
! RUN: %S/test_folding.sh %s %t %flang_fc1
! REQUIRES: shell
! Tests folding of CSHIFT (valid cases)
module m
integer, parameter :: arr(2,3) = reshape([1, 2, 3, 4, 5, 6], shape(arr))
logical, parameter :: test_sanity = all([arr] == [1, 2, 3, 4, 5, 6])
logical, parameter :: test_cshift_0 = all(cshift([1, 2, 3], 0) == [1, 2, 3])
logical, parameter :: test_cshift_1 = all(cshift([1, 2, 3], 1) == [2, 3, 1])
logical, parameter :: test_cshift_2 = all(cshift([1, 2, 3], 3) == [1, 2, 3])
logical, parameter :: test_cshift_3 = all(cshift([1, 2, 3], 4) == [2, 3, 1])
logical, parameter :: test_cshift_4 = all(cshift([1, 2, 3], -1) == [3, 1, 2])
logical, parameter :: test_cshift_5 = all([cshift(arr, 1, dim=1)] == [2, 1, 4, 3, 6, 5])
logical, parameter :: test_cshift_6 = all([cshift(arr, 1, dim=2)] == [3, 5, 1, 4, 6, 2])
logical, parameter :: test_cshift_7 = all([cshift(arr, [1, 2, 3])] == [2, 1, 3, 4, 6, 5])
logical, parameter :: test_cshift_8 = all([cshift(arr, [1, 2], dim=2)] == [3, 5, 1, 6, 2, 4])
end module