[flang] Lower type-bound procedure call needing dynamic dispatch to fir.dispatch

Lower call with polymorphic entities to fir.dispatch operation. This patch only
focus one lowering with simple scalar polymorphic entities. A follow-up patch
will deal with allocatble, pointer and array of polymorphic entities as they
require box manipulation for the passed-object.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D135649
This commit is contained in:
Valentin Clement 2022-10-12 15:24:21 +02:00
parent b1d7a95e4e
commit 7883900c04
No known key found for this signature in database
GPG Key ID: 086D54783C928776
9 changed files with 279 additions and 18 deletions

View File

@ -284,6 +284,14 @@ public:
/// procedure.
bool isIndirectCall() const;
/// Returns true if this is a call of a type-bound procedure with a
/// polymorphic entity.
bool requireDispatchCall() const;
/// Get the passed-object argument index. nullopt if there is no passed-object
/// index.
std::optional<unsigned> getPassArgIndex() const;
/// Return the procedure symbol if this is a call to a user defined
/// procedure.
const Fortran::semantics::Symbol *getProcedureSymbol() const;
@ -372,6 +380,10 @@ public:
/// called through pointers or not.
bool isIndirectCall() const { return false; }
/// On the callee side it does not matter whether the procedure is called
/// through dynamic dispatch or not.
bool requireDispatchCall() const { return false; };
/// Return the procedure symbol if this is a call to a user defined
/// procedure.
const Fortran::semantics::Symbol *getProcedureSymbol() const;

View File

@ -203,7 +203,7 @@ inline unsigned getRankOfShapeType(mlir::Type t) {
}
/// Get the memory reference type of the data pointer from the box type,
inline mlir::Type boxMemRefType(fir::BoxType t) {
inline mlir::Type boxMemRefType(fir::BaseBoxType t) {
auto eleTy = t.getEleTy();
if (!eleTy.isa<fir::PointerType, fir::HeapType>())
eleTy = fir::ReferenceType::get(t);

View File

@ -88,6 +88,36 @@ bool Fortran::lower::CallerInterface::isIndirectCall() const {
return false;
}
bool Fortran::lower::CallerInterface::requireDispatchCall() const {
// calls with NOPASS attribute still have their component so check if it is
// polymorphic.
if (const Fortran::evaluate::Component *component =
procRef.proc().GetComponent()) {
if (Fortran::semantics::IsPolymorphic(component->GetFirstSymbol()))
return true;
}
// calls with PASS attribute have the passed-object already set in its
// arguments. Just check if their is one.
std::optional<unsigned> passArg = getPassArgIndex();
if (passArg)
return true;
return false;
}
std::optional<unsigned>
Fortran::lower::CallerInterface::getPassArgIndex() const {
unsigned passArgIdx = 0;
std::optional<unsigned> passArg = std::nullopt;
for (const auto &arg : getCallDescription().arguments()) {
if (arg && arg->isPassedObject()) {
passArg = passArgIdx;
break;
}
++passArgIdx;
}
return passArg;
}
const Fortran::semantics::Symbol *
Fortran::lower::CallerInterface::getIfIndirectCallSymbol() const {
if (const Fortran::semantics::Symbol *symbol = procRef.proc().GetSymbol())

View File

@ -1993,8 +1993,10 @@ public:
}
mlir::Value base = fir::getBase(array);
auto seqTy =
fir::dyn_cast_ptrOrBoxEleTy(base.getType()).cast<fir::SequenceType>();
mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(base.getType());
if (auto classTy = eleTy.dyn_cast<fir::ClassType>())
eleTy = classTy.getEleTy();
auto seqTy = eleTy.cast<fir::SequenceType>();
assert(args.size() == seqTy.getDimension());
mlir::Type ty = builder.getRefType(seqTy.getEleTy());
auto addr = builder.create<fir::CoordinateOp>(loc, ty, base, args);
@ -2727,11 +2729,47 @@ public:
if (addHostAssociations)
operands.push_back(converter.hostAssocTupleValue());
auto call = builder.create<fir::CallOp>(loc, funcType.getResults(),
funcSymbolAttr, operands);
mlir::Value callResult;
unsigned callNumResults;
if (caller.requireDispatchCall()) {
// Procedure call requiring a dynamic dispatch. Call is created with
// fir.dispatch.
// Get the raw procedure name. The procedure name is not mangled in the
// binding table.
const auto &ultimateSymbol =
caller.getCallDescription().proc().GetSymbol()->GetUltimate();
auto procName = toStringRef(ultimateSymbol.name());
fir::DispatchOp dispatch;
if (std::optional<unsigned> passArg = caller.getPassArgIndex()) {
// PASS, PASS(arg-name)
dispatch = builder.create<fir::DispatchOp>(
loc, funcType.getResults(), procName, operands[*passArg], operands,
builder.getI32IntegerAttr(*passArg));
} else {
// NOPASS
const Fortran::evaluate::Component *component =
caller.getCallDescription().proc().GetComponent();
assert(component && "expect component for type-bound procedure call.");
fir::ExtendedValue pass =
symMap.lookupSymbol(component->GetFirstSymbol()).toExtendedValue();
dispatch = builder.create<fir::DispatchOp>(loc, funcType.getResults(),
procName, fir::getBase(pass),
operands, nullptr);
}
callResult = dispatch.getResult(0);
callNumResults = dispatch.getNumResults();
} else {
// Standard procedure call with fir.call.
auto call = builder.create<fir::CallOp>(loc, funcType.getResults(),
funcSymbolAttr, operands);
callResult = call.getResult(0);
callNumResults = call.getNumResults();
}
if (caller.mustSaveResult())
builder.create<fir::SaveResultOp>(loc, call.getResult(0),
builder.create<fir::SaveResultOp>(loc, callResult,
fir::getBase(allocatedResult.value()),
arrayResultShape, resultLengths);
@ -2754,7 +2792,7 @@ public:
return mlir::Value{}; // subroutine call
// For now, Fortran return values are implemented with a single MLIR
// function return value.
assert(call.getNumResults() == 1 &&
assert(callNumResults == 1 &&
"Expected exactly one result in FUNCTION call");
// Call a BIND(C) function that return a char.
@ -2764,10 +2802,10 @@ public:
funcType.getResults()[0].dyn_cast<fir::CharacterType>();
mlir::Value len = builder.createIntegerConstant(
loc, builder.getCharacterLengthType(), charTy.getLen());
return fir::CharBoxValue{call.getResult(0), len};
return fir::CharBoxValue{callResult, len};
}
return call.getResult(0);
return callResult;
}
/// Like genExtAddr, but ensure the address returned is a temporary even if \p
@ -6012,7 +6050,7 @@ private:
}
static mlir::Type unwrapBoxEleTy(mlir::Type ty) {
if (auto boxTy = ty.dyn_cast<fir::BoxType>())
if (auto boxTy = ty.dyn_cast<fir::BaseBoxType>())
return fir::unwrapRefType(boxTy.getEleTy());
return ty;
}
@ -7150,7 +7188,7 @@ private:
// Need an intermediate dereference if the boxed value
// appears in the middle of the component path or if it is
// on the right and this is not a pointer assignment.
if (auto boxTy = ty.dyn_cast<fir::BoxType>()) {
if (auto boxTy = ty.dyn_cast<fir::BaseBoxType>()) {
auto currentFunc = components.getExtendCoorRef();
auto loc = getLoc();
auto *bldr = &converter.getFirOpBuilder();
@ -7161,7 +7199,7 @@ private:
deref = true;
}
}
} else if (auto boxTy = ty.dyn_cast<fir::BoxType>()) {
} else if (auto boxTy = ty.dyn_cast<fir::BaseBoxType>()) {
ty = fir::unwrapRefType(boxTy.getEleTy());
auto recTy = ty.cast<fir::RecordType>();
ty = recTy.getType(name);
@ -7247,7 +7285,7 @@ private:
// assignment, then insert the dereference of the box before any
// conversion and store.
if (!isPointerAssignment()) {
if (auto boxTy = eleTy.dyn_cast<fir::BoxType>()) {
if (auto boxTy = eleTy.dyn_cast<fir::BaseBoxType>()) {
eleTy = fir::boxMemRefType(boxTy);
addr = builder.create<fir::BoxAddrOp>(loc, eleTy, addr);
eleTy = fir::unwrapRefType(eleTy);

View File

@ -155,6 +155,10 @@ Fortran::lower::mangle::mangleName(const Fortran::semantics::Symbol &symbol,
llvm::report_fatal_error(
"only derived type instances can be mangled");
},
[&](const Fortran::semantics::ProcBindingDetails &procBinding)
-> std::string {
return mangleName(procBinding.symbol(), keepExternalInScope);
},
[](const auto &) -> std::string { TODO_NOLOC("symbol mangling"); },
},
ultimateSymbol.details());

View File

@ -204,7 +204,7 @@ bool fir::MutableBoxValue::verify() const {
/// Debug verifier for BoxValue ctor. There is no guarantee this will
/// always be called.
bool fir::BoxValue::verify() const {
if (!addr.getType().isa<fir::BoxType>())
if (!addr.getType().isa<fir::BaseBoxType>())
return false;
if (!lbounds.empty() && lbounds.size() != rank())
return false;

View File

@ -460,7 +460,7 @@ mlir::Value fir::FirOpBuilder::createSlice(mlir::Location loc,
mlir::Value fir::FirOpBuilder::createBox(mlir::Location loc,
const fir::ExtendedValue &exv) {
mlir::Value itemAddr = fir::getBase(exv);
if (itemAddr.getType().isa<fir::BoxType>())
if (itemAddr.getType().isa<fir::BaseBoxType>())
return itemAddr;
auto elementType = fir::dyn_cast_ptrEleTy(itemAddr.getType());
if (!elementType) {
@ -741,7 +741,7 @@ static llvm::SmallVector<mlir::Value> getFromBox(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Type valTy,
mlir::Value boxVal) {
if (auto boxTy = valTy.dyn_cast<fir::BoxType>()) {
if (auto boxTy = valTy.dyn_cast<fir::BaseBoxType>()) {
auto eleTy = fir::unwrapAllRefAndSeqType(boxTy.getEleTy());
if (auto recTy = eleTy.dyn_cast<fir::RecordType>()) {
if (recTy.getNumLenParams() > 0) {
@ -795,7 +795,7 @@ llvm::SmallVector<mlir::Value>
fir::factory::getTypeParams(mlir::Location loc, fir::FirOpBuilder &builder,
fir::ArrayLoadOp load) {
mlir::Type memTy = load.getMemref().getType();
if (auto boxTy = memTy.dyn_cast<fir::BoxType>())
if (auto boxTy = memTy.dyn_cast<fir::BaseBoxType>())
return getFromBox(loc, builder, boxTy, load.getMemref());
return load.getTypeparams();
}

View File

@ -917,7 +917,8 @@ mlir::LogicalResult fir::ConvertOp::verify() {
(inType.isa<fir::BoxType>() && outType.isa<fir::BoxType>()) ||
(inType.isa<fir::BoxProcType>() && outType.isa<fir::BoxProcType>()) ||
(fir::isa_complex(inType) && fir::isa_complex(outType)) ||
(fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)))
(fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)) ||
(fir::isPolymorphicType(inType) && fir::isPolymorphicType(outType)))
return mlir::success();
return emitOpError("invalid type conversion");
}

View File

@ -0,0 +1,176 @@
! RUN: bbc -polymorphic-type -emit-fir %s -o - | FileCheck %s
! Tests the different possible type involving polymorphic entities.
module call_dispatch
interface
subroutine nopass_defferred(x)
real :: x(:)
end subroutine
end interface
type p1
integer :: a
integer :: b
contains
procedure, nopass :: tbp_nopass
procedure :: tbp_pass
procedure, pass(this) :: tbp_pass_arg0
procedure, pass(this) :: tbp_pass_arg1
procedure, nopass :: proc1 => p1_proc1_nopass
procedure :: proc2 => p1_proc2
procedure, pass(this) :: proc3 => p1_proc3_arg0
procedure, pass(this) :: proc4 => p1_proc4_arg1
procedure, nopass :: p1_fct1_nopass
procedure :: p1_fct2
procedure, pass(this) :: p1_fct3_arg0
procedure, pass(this) :: p1_fct4_arg1
end type
type, abstract :: a1
real :: a
real :: b
contains
procedure(nopass_defferred), deferred, nopass :: nopassd
end type
contains
! ------------------------------------------------------------------------------
! Test lowering of type-bound procedure call on polymorphic entities
! ------------------------------------------------------------------------------
function p1_fct1_nopass()
real :: p1_fct1_nopass
end function
! CHECK-LABEL: func.func @_QMcall_dispatchPp1_fct1_nopass() -> f32
function p1_fct2(p)
real :: p1_fct2
class(p1) :: p
end function
! CHECK-LABEL: func.func @_QMcall_dispatchPp1_fct2(%{{.*}}: !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) -> f32
function p1_fct3_arg0(this)
real :: p1_fct2
class(p1) :: this
end function
! CHECK-LABEL: func.func @_QMcall_dispatchPp1_fct3_arg0(%{{.*}}: !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) -> f32
function p1_fct4_arg1(i, this)
real :: p1_fct2
integer :: i
class(p1) :: this
end function
! CHECK-LABEL: func.func @_QMcall_dispatchPp1_fct4_arg1(%{{.*}}: !fir.ref<i32>, %{{.*}}: !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) -> f32
subroutine p1_proc1_nopass()
end subroutine
! CHECK-LABEL: func.func @_QMcall_dispatchPp1_proc1_nopass()
subroutine p1_proc2(p)
class(p1) :: p
end subroutine
! CHECK-LABEL: func.func @_QMcall_dispatchPp1_proc2(%{{.*}}: !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>)
subroutine p1_proc3_arg0(this)
class(p1) :: this
end subroutine
! CHECK-LABEL: func.func @_QMcall_dispatchPp1_proc3_arg0(%{{.*}}: !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>)
subroutine p1_proc4_arg1(i, this)
integer, intent(in) :: i
class(p1) :: this
end subroutine
! CHECK-LABEL: func.func @_QMcall_dispatchPp1_proc4_arg1(%{{.*}}: !fir.ref<i32>, %{{.*}}: !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>)
subroutine tbp_nopass()
end subroutine
! CHECK-LABEL: func.func @_QMcall_dispatchPtbp_nopass()
subroutine tbp_pass(t)
class(p1) :: t
end subroutine
! CHECK-LABEL: func.func @_QMcall_dispatchPtbp_pass(%{{.*}}: !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>)
subroutine tbp_pass_arg0(this)
class(p1) :: this
end subroutine
! CHECK-LABEL: func.func @_QMcall_dispatchPtbp_pass_arg0(%{{.*}}: !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>)
subroutine tbp_pass_arg1(i, this)
integer, intent(in) :: i
class(p1) :: this
end subroutine
! CHECK-LABEL: func.func @_QMcall_dispatchPtbp_pass_arg1(%{{.*}}: !fir.ref<i32>, %{{.*}}: !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>)
subroutine check_dispatch(p)
class(p1) :: p
real :: a
call p%tbp_nopass()
call p%tbp_pass()
call p%tbp_pass_arg0()
call p%tbp_pass_arg1(1)
call p%proc1()
call p%proc2()
call p%proc3()
call p%proc4(1)
a = p%p1_fct1_nopass()
a = p%p1_fct2()
a = p%p1_fct3_arg0()
a = p%p1_fct4_arg1(1)
end subroutine
! CHECK-LABEL: func.func @_QMcall_dispatchPcheck_dispatch(
! CHECK-SAME: %[[P:.*]]: !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>> {fir.bindc_name = "p"}) {
! CHECK: fir.dispatch "tbp_nopass"(%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>){{$}}
! CHECK: fir.dispatch "tbp_pass"(%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) (%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) {pass_arg_pos = 0 : i32}
! CHECK: fir.dispatch "tbp_pass_arg0"(%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) (%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) {pass_arg_pos = 0 : i32}
! CHECK: fir.dispatch "tbp_pass_arg1"(%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) (%{{.*}}, %[[P]] : !fir.ref<i32>, !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32}
! CHECK: fir.dispatch "proc1"(%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>){{$}}
! CHECK: fir.dispatch "proc2"(%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) (%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) {pass_arg_pos = 0 : i32}
! CHECK: fir.dispatch "proc3"(%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) (%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) {pass_arg_pos = 0 : i32}
! CHECK: fir.dispatch "proc4"(%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) (%{{.*}}, %[[P]] : !fir.ref<i32>, !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32}
! CHECK: %{{.*}} = fir.dispatch "p1_fct1_nopass"(%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) -> f32{{$}}
! CHECK: %{{.*}} = fir.dispatch "p1_fct2"(%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) (%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) -> f32 {pass_arg_pos = 0 : i32}
! CHECK: %{{.*}} = fir.dispatch "p1_fct3_arg0"(%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) (%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) -> f32 {pass_arg_pos = 0 : i32}
! CHECK: %{{.*}} = fir.dispatch "p1_fct4_arg1"(%[[P]] : !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) (%{{.*}}, %[[P]] : !fir.ref<i32>, !fir.class<!fir.type<_QMcall_dispatchTp1{a:i32,b:i32}>>) -> f32 {pass_arg_pos = 1 : i32}
subroutine check_dispatch_deferred(a, x)
class(a1) :: a
real :: x(:)
call a%nopassd(x)
end subroutine
! CHECK-LABEL: func.func @_QMcall_dispatchPcheck_dispatch_deferred(
! CHECK-SAME: %[[ARG0:.*]]: !fir.class<!fir.type<_QMcall_dispatchTa1{a:f32,b:f32}>> {fir.bindc_name = "a"},
! CHECK-SAME: %[[ARG1:.*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "x"}) {
! CHECK: fir.dispatch "nopassd"(%[[ARG0]] : !fir.class<!fir.type<_QMcall_dispatchTa1{a:f32,b:f32}>>) (%[[ARG1]] : !fir.box<!fir.array<?xf32>>)
! ------------------------------------------------------------------------------
! Test that direct call is emitted when the type is known
! ------------------------------------------------------------------------------
subroutine check_nodispatch(t)
type(p1) :: t
call t%tbp_nopass()
call t%tbp_pass()
call t%tbp_pass_arg0()
call t%tbp_pass_arg1(1)
end subroutine
! CHECK-LABEL: func.func @_QMcall_dispatchPcheck_nodispatch
! CHECK: fir.call @_QMcall_dispatchPtbp_nopass
! CHECK: fir.call @_QMcall_dispatchPtbp_pass
! CHECK: fir.call @_QMcall_dispatchPtbp_pass_arg0
! CHECK: fir.call @_QMcall_dispatchPtbp_pass_arg1
end module