[flang] Lower associate construct

This patch lowers the `associate` construct.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D121239

Co-authored-by: V Donaldson <vdonaldson@nvidia.com>
Co-authored-by: Jean Perier <jperier@nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
This commit is contained in:
Valentin Clement 2022-03-08 22:08:02 +01:00
parent 1e016c3bd5
commit a49bf0ac38
No known key found for this signature in database
GPG Key ID: 086D54783C928776
3 changed files with 133 additions and 18 deletions

View File

@ -1220,7 +1220,30 @@ private:
} }
void genFIR(const Fortran::parser::AssociateConstruct &) { void genFIR(const Fortran::parser::AssociateConstruct &) {
TODO(toLocation(), "AssociateConstruct lowering"); Fortran::lower::StatementContext stmtCtx;
Fortran::lower::pft::Evaluation &eval = getEval();
for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
if (auto *stmt = e.getIf<Fortran::parser::AssociateStmt>()) {
if (eval.lowerAsUnstructured())
maybeStartBlock(e.block);
localSymbols.pushScope();
for (const Fortran::parser::Association &assoc :
std::get<std::list<Fortran::parser::Association>>(stmt->t)) {
Fortran::semantics::Symbol &sym =
*std::get<Fortran::parser::Name>(assoc.t).symbol;
const Fortran::lower::SomeExpr &selector =
*sym.get<Fortran::semantics::AssocEntityDetails>().expr();
localSymbols.addSymbol(sym, genAssociateSelector(selector, stmtCtx));
}
} else if (e.getIf<Fortran::parser::EndAssociateStmt>()) {
if (eval.lowerAsUnstructured())
maybeStartBlock(e.block);
stmtCtx.finalize();
localSymbols.popScope();
} else {
genFIR(e);
}
}
} }
void genFIR(const Fortran::parser::BlockConstruct &blockConstruct) { void genFIR(const Fortran::parser::BlockConstruct &blockConstruct) {
@ -1571,10 +1594,6 @@ private:
genFIRBranch(getEval().controlSuccessor->block); genFIRBranch(getEval().controlSuccessor->block);
} }
void genFIR(const Fortran::parser::AssociateStmt &) {
TODO(toLocation(), "AssociateStmt lowering");
}
void genFIR(const Fortran::parser::CaseStmt &) { void genFIR(const Fortran::parser::CaseStmt &) {
TODO(toLocation(), "CaseStmt lowering"); TODO(toLocation(), "CaseStmt lowering");
} }
@ -1587,10 +1606,6 @@ private:
TODO(toLocation(), "ElseStmt lowering"); TODO(toLocation(), "ElseStmt lowering");
} }
void genFIR(const Fortran::parser::EndAssociateStmt &) {
TODO(toLocation(), "EndAssociateStmt lowering");
}
void genFIR(const Fortran::parser::EndDoStmt &) { void genFIR(const Fortran::parser::EndDoStmt &) {
TODO(toLocation(), "EndDoStmt lowering"); TODO(toLocation(), "EndDoStmt lowering");
} }
@ -1604,7 +1619,9 @@ private:
} }
// Nop statements - No code, or code is generated at the construct level. // Nop statements - No code, or code is generated at the construct level.
void genFIR(const Fortran::parser::AssociateStmt &) {} // nop
void genFIR(const Fortran::parser::ContinueStmt &) {} // nop void genFIR(const Fortran::parser::ContinueStmt &) {} // nop
void genFIR(const Fortran::parser::EndAssociateStmt &) {} // nop
void genFIR(const Fortran::parser::EndFunctionStmt &) {} // nop void genFIR(const Fortran::parser::EndFunctionStmt &) {} // nop
void genFIR(const Fortran::parser::EndIfStmt &) {} // nop void genFIR(const Fortran::parser::EndIfStmt &) {} // nop
void genFIR(const Fortran::parser::EndSubroutineStmt &) {} // nop void genFIR(const Fortran::parser::EndSubroutineStmt &) {} // nop

View File

@ -21,31 +21,36 @@
// recursively build the vector of module scopes // recursively build the vector of module scopes
static void moduleNames(const Fortran::semantics::Scope &scope, static void moduleNames(const Fortran::semantics::Scope &scope,
llvm::SmallVector<llvm::StringRef, 2> &result) { llvm::SmallVector<llvm::StringRef> &result) {
if (scope.IsTopLevel()) { if (scope.IsTopLevel())
return; return;
}
moduleNames(scope.parent(), result); moduleNames(scope.parent(), result);
if (scope.kind() == Fortran::semantics::Scope::Kind::Module) if (scope.kind() == Fortran::semantics::Scope::Kind::Module)
if (auto *symbol = scope.symbol()) if (const Fortran::semantics::Symbol *symbol = scope.symbol())
result.emplace_back(toStringRef(symbol->name())); result.emplace_back(toStringRef(symbol->name()));
} }
static llvm::SmallVector<llvm::StringRef, 2> static llvm::SmallVector<llvm::StringRef>
moduleNames(const Fortran::semantics::Symbol &symbol) { moduleNames(const Fortran::semantics::Symbol &symbol) {
const auto &scope = symbol.owner(); const Fortran::semantics::Scope &scope = symbol.owner();
llvm::SmallVector<llvm::StringRef, 2> result; llvm::SmallVector<llvm::StringRef> result;
moduleNames(scope, result); moduleNames(scope, result);
return result; return result;
} }
static llvm::Optional<llvm::StringRef> static llvm::Optional<llvm::StringRef>
hostName(const Fortran::semantics::Symbol &symbol) { hostName(const Fortran::semantics::Symbol &symbol) {
const auto &scope = symbol.owner(); const Fortran::semantics::Scope &scope = symbol.owner();
if (scope.kind() == Fortran::semantics::Scope::Kind::Subprogram) { if (scope.kind() == Fortran::semantics::Scope::Kind::Subprogram) {
assert(scope.symbol() && "subprogram scope must have a symbol"); assert(scope.symbol() && "subprogram scope must have a symbol");
return {toStringRef(scope.symbol()->name())}; return toStringRef(scope.symbol()->name());
} }
if (scope.kind() == Fortran::semantics::Scope::Kind::MainProgram)
// Do not use the main program name, if any, because it may lead to name
// collision with procedures with the same name in other compilation units
// (technically illegal, but all compilers are able to compile and link
// properly these programs).
return llvm::StringRef("");
return {}; return {};
} }

View File

@ -0,0 +1,93 @@
! RUN: bbc -emit-fir -o - %s | FileCheck %s
! CHECK-LABEL: func @_QQmain
program p
! CHECK-DAG: [[I:%[0-9]+]] = fir.alloca i32 {{{.*}}uniq_name = "_QFEi"}
! CHECK-DAG: [[N:%[0-9]+]] = fir.alloca i32 {{{.*}}uniq_name = "_QFEn"}
! CHECK: [[T:%[0-9]+]] = fir.address_of(@_QFEt) : !fir.ref<!fir.array<3xi32>>
integer :: n, foo, t(3)
! CHECK: [[N]]
! CHECK-COUNT-3: fir.coordinate_of [[T]]
n = 100; t(1) = 111; t(2) = 222; t(3) = 333
! CHECK: fir.load [[N]]
! CHECK: addi {{.*}} %c5
! CHECK: fir.store %{{[0-9]*}} to [[B:%[0-9]+]]
! CHECK: [[C:%[0-9]+]] = fir.coordinate_of [[T]]
! CHECK: fir.call @_QPfoo
! CHECK: fir.store %{{[0-9]*}} to [[D:%[0-9]+]]
associate (a => n, b => n+5, c => t(2), d => foo(7))
! CHECK: fir.load [[N]]
! CHECK: addi %{{[0-9]*}}, %c1
! CHECK: fir.store %{{[0-9]*}} to [[N]]
a = a + 1
! CHECK: fir.load [[C]]
! CHECK: addi %{{[0-9]*}}, %c1
! CHECK: fir.store %{{[0-9]*}} to [[C]]
c = c + 1
! CHECK: fir.load [[N]]
! CHECK: addi %{{[0-9]*}}, %c1
! CHECK: fir.store %{{[0-9]*}} to [[N]]
n = n + 1
! CHECK: fir.load [[N]]
! CHECK: fir.embox [[T]]
! CHECK: fir.load [[N]]
! CHECK: fir.load [[B]]
! CHECK: fir.load [[C]]
! CHECK: fir.load [[D]]
print*, n, t, a, b, c, d ! expect: 102 111 223 333 102 105 223 7
end associate
call nest
associate (x=>i)
! CHECK: [[IVAL:%[0-9]+]] = fir.load [[I]] : !fir.ref<i32>
! CHECK: [[TWO:%.*]] = arith.constant 2 : i32
! CHECK: arith.cmpi eq, [[IVAL]], [[TWO]] : i32
! CHECK: ^bb
if (x==2) goto 9
! CHECK: [[IVAL:%[0-9]+]] = fir.load [[I]] : !fir.ref<i32>
! CHECK: [[THREE:%.*]] = arith.constant 3 : i32
! CHECK: arith.cmpi eq, [[IVAL]], [[THREE]] : i32
! CHECK: ^bb
! CHECK: fir.call @_FortranAStopStatementText
! CHECK: fir.unreachable
! CHECK: ^bb
if (x==3) stop 'Halt'
! CHECK: fir.call @_FortranAioOutputAscii
print*, "ok"
9 end associate
end
! CHECK-LABEL: func @_QPfoo
integer function foo(x)
integer x
integer, save :: i = 0
i = i + x
foo = i
end function foo
! CHECK-LABEL: func @_QPnest(
subroutine nest
integer, parameter :: n = 10
integer :: a(5), b(n)
associate (s => sequence(size(a)))
a = s
associate(t => sequence(n))
b = t
! CHECK: cond_br %{{.*}}, [[BB1:\^bb[0-9]]], [[BB2:\^bb[0-9]]]
! CHECK: [[BB1]]:
! CHECK: br [[BB3:\^bb[0-9]]]
! CHECK: [[BB2]]:
if (a(1) > b(1)) goto 9
end associate
a = a * a
end associate
! CHECK: br [[BB3]]
! CHECK: [[BB3]]:
9 print *, sum(a), sum(b) ! expect: 55 55
contains
function sequence(n)
integer sequence(n)
sequence = [(i,i=1,n)]
end function
end subroutine nest