From b0eef1eef0500315bf74721dda3d7a8e3c6a6eac Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Mon, 11 Oct 2021 10:09:31 +0200 Subject: [PATCH] [fir] Add the abstract result conversion pass Add pass that convert abstract result to function argument. This pass is needed before the conversion to LLVM IR. This patch is part of the upstreaming effort from fir-dev branch. Reviewed By: schweitz Differential Revision: https://reviews.llvm.org/D111146 Co-authored-by: Eric Schweitz --- .../flang/Optimizer/Transforms/Passes.h | 1 + .../flang/Optimizer/Transforms/Passes.td | 21 +- flang/lib/Optimizer/Dialect/FIROps.cpp | 3 +- .../Optimizer/Transforms/AbstractResult.cpp | 288 ++++++++++++++++++ flang/lib/Optimizer/Transforms/CMakeLists.txt | 1 + flang/test/Fir/abstract-results.fir | 255 ++++++++++++++++ 6 files changed, 567 insertions(+), 2 deletions(-) create mode 100644 flang/lib/Optimizer/Transforms/AbstractResult.cpp create mode 100644 flang/test/Fir/abstract-results.fir diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h index fc689b037297..5dc784ff0b50 100644 --- a/flang/include/flang/Optimizer/Transforms/Passes.h +++ b/flang/include/flang/Optimizer/Transforms/Passes.h @@ -26,6 +26,7 @@ namespace fir { // Passes defined in Passes.td //===----------------------------------------------------------------------===// +std::unique_ptr createAbstractResultOptPass(); std::unique_ptr createAffineDemotionPass(); std::unique_ptr createCharacterConversionPass(); std::unique_ptr createExternalNameConversionPass(); diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td index b207ad70ba9a..309ef43d766d 100644 --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // // This file contains definitions for passes within the Optimizer/Transforms/ -// directory. +// directory. // //===----------------------------------------------------------------------===// @@ -16,6 +16,25 @@ include "mlir/Pass/PassBase.td" +def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::FuncOp"> { + let summary = "Convert fir.array, fir.box and fir.rec function result to " + "function argument"; + let description = [{ + This pass is required before code gen to the LLVM IR dialect, + including the pre-cg rewrite pass. + }]; + let constructor = "::fir::createAbstractResultOptPass()"; + let dependentDialects = [ + "fir::FIROpsDialect", "mlir::StandardOpsDialect" + ]; + let options = [ + Option<"passResultAsBox", "abstract-result-as-box", + "bool", /*default=*/"false", + "Pass fir.array result as fir.box> argument instead" + " of fir.ref>."> + ]; +} + def AffineDialectPromotion : FunctionPass<"promote-to-affine"> { let summary = "Promotes `fir.{do_loop,if}` to `affine.{for,if}`."; let description = [{ diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 94e8d624b338..33db64c6687f 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -623,7 +623,8 @@ void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, llvm::ArrayRef results, mlir::ValueRange operands) { result.addOperands(operands); - result.addAttribute(getCalleeAttrName(), callee); + if (callee) + result.addAttribute(getCalleeAttrName(), callee); result.addTypes(results); } diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp new file mode 100644 index 000000000000..21df4180e14c --- /dev/null +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -0,0 +1,288 @@ +//===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/TypeSwitch.h" + +#define DEBUG_TYPE "flang-abstract-result-opt" + +namespace fir { +namespace { + +struct AbstractResultOptions { + // Always pass result as a fir.box argument. + bool boxResult = false; + // New function block argument for the result if the current FuncOp had + // an abstract result. + mlir::Value newArg; +}; + +static bool mustConvertCallOrFunc(mlir::FunctionType type) { + if (type.getNumResults() == 0) + return false; + auto resultType = type.getResult(0); + return resultType.isa(); +} + +static mlir::Type getResultArgumentType(mlir::Type resultType, + const AbstractResultOptions &options) { + return llvm::TypeSwitch(resultType) + .Case( + [&](mlir::Type type) -> mlir::Type { + if (options.boxResult) + return fir::BoxType::get(type); + return fir::ReferenceType::get(type); + }) + .Case([](mlir::Type type) -> mlir::Type { + return fir::ReferenceType::get(type); + }) + .Default([](mlir::Type) -> mlir::Type { + llvm_unreachable("bad abstract result type"); + }); +} + +static mlir::FunctionType +getNewFunctionType(mlir::FunctionType funcTy, + const AbstractResultOptions &options) { + auto resultType = funcTy.getResult(0); + auto argTy = getResultArgumentType(resultType, options); + llvm::SmallVector newInputTypes = {argTy}; + newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); + return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, + /*resultTypes=*/{}); +} + +static bool mustEmboxResult(mlir::Type resultType, + const AbstractResultOptions &options) { + return resultType.isa() && + options.boxResult; +} + +class CallOpConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + CallOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt) + : OpRewritePattern(context), options{opt} {} + mlir::LogicalResult + matchAndRewrite(fir::CallOp callOp, + mlir::PatternRewriter &rewriter) const override { + auto loc = callOp.getLoc(); + auto result = callOp->getResult(0); + if (!result.hasOneUse()) { + mlir::emitError(loc, + "calls with abstract result must have exactly one user"); + return mlir::failure(); + } + auto saveResult = + mlir::dyn_cast(result.use_begin().getUser()); + if (!saveResult) { + mlir::emitError( + loc, "calls with abstract result must be used in fir.save_result"); + return mlir::failure(); + } + auto argType = getResultArgumentType(result.getType(), options); + auto buffer = saveResult.memref(); + mlir::Value arg = buffer; + if (mustEmboxResult(result.getType(), options)) + arg = rewriter.create( + loc, argType, buffer, saveResult.shape(), /*slice*/ mlir::Value{}, + saveResult.typeparams()); + + llvm::SmallVector newResultTypes; + if (callOp.callee()) { + llvm::SmallVector newOperands = {arg}; + newOperands.append(callOp.getOperands().begin(), + callOp.getOperands().end()); + rewriter.create(loc, callOp.callee().getValue(), + newResultTypes, newOperands); + } else { + // Indirect calls. + llvm::SmallVector newInputTypes = {argType}; + for (auto operand : callOp.getOperands().drop_front()) + newInputTypes.push_back(operand.getType()); + auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes, + newResultTypes); + + llvm::SmallVector newOperands; + newOperands.push_back( + rewriter.create(loc, funTy, callOp.getOperand(0))); + newOperands.push_back(arg); + newOperands.append(callOp.getOperands().begin() + 1, + callOp.getOperands().end()); + rewriter.create(loc, mlir::SymbolRefAttr{}, newResultTypes, + newOperands); + } + callOp->dropAllReferences(); + rewriter.eraseOp(callOp); + return mlir::success(); + } + +private: + const AbstractResultOptions &options; +}; + +class SaveResultOpConversion + : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + SaveResultOpConversion(mlir::MLIRContext *context) + : OpRewritePattern(context) {} + mlir::LogicalResult + matchAndRewrite(fir::SaveResultOp op, + mlir::PatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +class ReturnOpConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + ReturnOpConversion(mlir::MLIRContext *context, + const AbstractResultOptions &opt) + : OpRewritePattern(context), options{opt} {} + mlir::LogicalResult + matchAndRewrite(mlir::ReturnOp ret, + mlir::PatternRewriter &rewriter) const override { + rewriter.setInsertionPoint(ret); + auto returnedValue = ret.getOperand(0); + bool replacedStorage = false; + if (auto *op = returnedValue.getDefiningOp()) + if (auto load = mlir::dyn_cast(op)) { + auto resultStorage = load.memref(); + load.memref().replaceAllUsesWith(options.newArg); + replacedStorage = true; + if (auto *alloc = resultStorage.getDefiningOp()) + if (alloc->use_empty()) + rewriter.eraseOp(alloc); + } + // The result storage may have been optimized out by a memory to + // register pass, this is possible for fir.box results, or fir.record + // with no length parameters. Simply store the result in the result storage. + // at the return point. + if (!replacedStorage) + rewriter.create(ret.getLoc(), returnedValue, + options.newArg); + rewriter.replaceOpWithNewOp(ret); + return mlir::success(); + } + +private: + const AbstractResultOptions &options; +}; + +class AddrOfOpConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + AddrOfOpConversion(mlir::MLIRContext *context, + const AbstractResultOptions &opt) + : OpRewritePattern(context), options{opt} {} + mlir::LogicalResult + matchAndRewrite(fir::AddrOfOp addrOf, + mlir::PatternRewriter &rewriter) const override { + auto oldFuncTy = addrOf.getType().cast(); + auto newFuncTy = getNewFunctionType(oldFuncTy, options); + auto newAddrOf = rewriter.create(addrOf.getLoc(), newFuncTy, + addrOf.symbol()); + // Rather than converting all op a function pointer might transit through + // (e.g calls, stores, loads, converts...), cast new type to the abstract + // type. A conversion will be added when calling indirect calls of abstract + // types. + rewriter.replaceOpWithNewOp(addrOf, oldFuncTy, newAddrOf); + return mlir::success(); + } + +private: + const AbstractResultOptions &options; +}; + +class AbstractResultOpt : public fir::AbstractResultOptBase { +public: + void runOnOperation() override { + auto *context = &getContext(); + auto func = getOperation(); + auto loc = func.getLoc(); + mlir::OwningRewritePatternList patterns(context); + mlir::ConversionTarget target = *context; + AbstractResultOptions options{passResultAsBox.getValue(), + /*newArg=*/{}}; + + // Convert function type itself if it has an abstract result + auto funcTy = func.getType().cast(); + if (mustConvertCallOrFunc(funcTy)) { + func.setType(getNewFunctionType(funcTy, options)); + unsigned zero = 0; + if (!func.empty()) { + // Insert new argument + mlir::OpBuilder rewriter(context); + auto resultType = funcTy.getResult(0); + auto argTy = getResultArgumentType(resultType, options); + options.newArg = func.front().insertArgument(zero, argTy); + if (mustEmboxResult(resultType, options)) { + auto bufferType = fir::ReferenceType::get(resultType); + rewriter.setInsertionPointToStart(&func.front()); + options.newArg = + rewriter.create(loc, bufferType, options.newArg); + } + patterns.insert(context, options); + target.addDynamicallyLegalOp( + [](mlir::ReturnOp ret) { return ret.operands().empty(); }); + } + } + + if (func.empty()) + return; + + // Convert the calls and, if needed, the ReturnOp in the function body. + target.addLegalDialect(); + target.addIllegalOp(); + target.addDynamicallyLegalOp([](fir::CallOp call) { + return !mustConvertCallOrFunc(call.getFunctionType()); + }); + target.addDynamicallyLegalOp([](fir::AddrOfOp addrOf) { + if (auto funTy = addrOf.getType().dyn_cast()) + return !mustConvertCallOrFunc(funTy); + return true; + }); + target.addDynamicallyLegalOp([](fir::DispatchOp dispatch) { + if (dispatch->getNumResults() != 1) + return true; + auto resultType = dispatch->getResult(0).getType(); + if (resultType.isa()) { + mlir::emitError(dispatch.getLoc(), + "TODO: dispatchOp with abstract results"); + return false; + } + return true; + }); + + patterns.insert(context, options); + patterns.insert(context); + patterns.insert(context, options); + if (mlir::failed( + mlir::applyPartialConversion(func, target, std::move(patterns)))) { + mlir::emitError(func.getLoc(), "error in converting abstract results\n"); + signalPassFailure(); + } + } +}; +} // end anonymous namespace +} // namespace fir + +std::unique_ptr fir::createAbstractResultOptPass() { + return std::make_unique(); +} diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt index 6465ba8c5599..99b022edb948 100644 --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_flang_library(FIRTransforms + AbstractResult.cpp AffinePromotion.cpp AffineDemotion.cpp CharacterConversion.cpp diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir new file mode 100644 index 000000000000..e7b24f268acd --- /dev/null +++ b/flang/test/Fir/abstract-results.fir @@ -0,0 +1,255 @@ +// Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to +// functions that take an additional argument for the result. + +// RUN: fir-opt %s --abstract-result-opt | FileCheck %s +// RUN: fir-opt %s --abstract-result-opt=abstract-result-as-box | FileCheck %s --check-prefix=CHECK-BOX + +// ----------------------- Test declaration rewrite ---------------------------- + +// CHECK-LABEL: func private @arrayfunc(!fir.ref>, i32) +// CHECK-BOX-LABEL: func private @arrayfunc(!fir.box>, i32) +func private @arrayfunc(i32) -> !fir.array + +// CHECK-LABEL: func private @derivedfunc(!fir.ref>, f32) +// CHECK-BOX-LABEL: func private @derivedfunc(!fir.box>, f32) +func private @derivedfunc(f32) -> !fir.type + +// CHECK-LABEL: func private @boxfunc(!fir.ref>>, i64) +// CHECK-BOX-LABEL: func private @boxfunc(!fir.ref>>, i64) +func private @boxfunc(i64) -> !fir.box> + + +// ------------------------ Test callee rewrite -------------------------------- + +// CHECK-LABEL: func private @arrayfunc_callee( +// CHECK-SAME: %[[buffer:.*]]: !fir.ref>, %[[n:.*]]: index) { +// CHECK-BOX-LABEL: func private @arrayfunc_callee( +// CHECK-BOX-SAME: %[[box:.*]]: !fir.box>, %[[n:.*]]: index) { +func private @arrayfunc_callee(%n : index) -> !fir.array { + %buffer = fir.alloca !fir.array, %n + // Do something with result (res(4) = 42.) + %c4 = constant 4 : i64 + %coor = fir.coordinate_of %buffer, %c4 : (!fir.ref>, i64) -> !fir.ref + %cst = constant 4.200000e+01 : f32 + fir.store %cst to %coor : !fir.ref + %res = fir.load %buffer : !fir.ref> + return %res : !fir.array + + // CHECK-DAG: %[[coor:.*]] = fir.coordinate_of %[[buffer]], %{{.*}} : (!fir.ref>, i64) -> !fir.ref + // CHECK-DAG: fir.store %{{.*}} to %[[coor]] : !fir.ref + // CHECK: return + + // CHECK-BOX: %[[buffer:.*]] = fir.box_addr %[[box]] : (!fir.box>) -> !fir.ref> + // CHECK-BOX-DAG: %[[coor:.*]] = fir.coordinate_of %[[buffer]], %{{.*}} : (!fir.ref>, i64) -> !fir.ref + // CHECK-BOX-DAG: fir.store %{{.*}} to %[[coor]] : !fir.ref + // CHECK-BOX: return +} + + +// CHECK-LABEL: func @derivedfunc_callee( +// CHECK-SAME: %[[buffer:.*]]: !fir.ref>, %[[v:.*]]: f32) { +// CHECK-BOX-LABEL: func @derivedfunc_callee( +// CHECK-BOX-SAME: %[[box:.*]]: !fir.box>, %[[v:.*]]: f32) { +func @derivedfunc_callee(%v: f32) -> !fir.type { + %buffer = fir.alloca !fir.type + %0 = fir.field_index x, !fir.type + %1 = fir.coordinate_of %buffer, %0 : (!fir.ref>, !fir.field) -> !fir.ref + fir.store %v to %1 : !fir.ref + %res = fir.load %buffer : !fir.ref> + return %res : !fir.type + + // CHECK: %[[coor:.*]] = fir.coordinate_of %[[buffer]], %{{.*}} : (!fir.ref>, !fir.field) -> !fir.ref + // CHECK: fir.store %[[v]] to %[[coor]] : !fir.ref + // CHECK: return + + // CHECK-BOX: %[[buffer:.*]] = fir.box_addr %[[box]] : (!fir.box>) -> !fir.ref> + // CHECK-BOX: %[[coor:.*]] = fir.coordinate_of %[[buffer]], %{{.*}} : (!fir.ref>, !fir.field) -> !fir.ref + // CHECK-BOX: fir.store %[[v]] to %[[coor]] : !fir.ref + // CHECK-BOX: return +} + +// CHECK-LABEL: func @boxfunc_callee( +// CHECK-SAME: %[[buffer:.*]]: !fir.ref>>) { +// CHECK-BOX-LABEL: func @boxfunc_callee( +// CHECK-BOX-SAME: %[[buffer:.*]]: !fir.ref>>) { +func @boxfunc_callee() -> !fir.box> { + %alloc = fir.allocmem f64 + %res = fir.embox %alloc : (!fir.heap) -> !fir.box> + return %res : !fir.box> + // CHECK: %[[box:.*]] = fir.embox %{{.*}} : (!fir.heap) -> !fir.box> + // CHECK: fir.store %[[box]] to %[[buffer]] : !fir.ref>> + // CHECK: return + + // CHECK-BOX: %[[box:.*]] = fir.embox %{{.*}} : (!fir.heap) -> !fir.box> + // CHECK-BOX: fir.store %[[box]] to %[[buffer]] : !fir.ref>> + // CHECK-BOX: return +} + +// ------------------------ Test caller rewrite -------------------------------- + +// CHECK-LABEL: func @call_arrayfunc() { +// CHECK-BOX-LABEL: func @call_arrayfunc() { +func @call_arrayfunc() { + %c100 = constant 100 : index + %buffer = fir.alloca !fir.array, %c100 + %shape = fir.shape %c100 : (index) -> !fir.shape<1> + %res = fir.call @arrayfunc_callee(%c100) : (index) -> !fir.array + fir.save_result %res to %buffer(%shape) : !fir.array, !fir.ref>, !fir.shape<1> + return + + // CHECK: %[[c100:.*]] = constant 100 : index + // CHECK: %[[buffer:.*]] = fir.alloca !fir.array, %[[c100]] + // CHECK: fir.call @arrayfunc_callee(%[[buffer]], %[[c100]]) : (!fir.ref>, index) -> () + // CHECK-NOT: fir.save_result + + // CHECK-BOX: %[[c100:.*]] = constant 100 : index + // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.array, %[[c100]] + // CHECK-BOX: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1> + // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]](%[[shape]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> + // CHECK-BOX: fir.call @arrayfunc_callee(%[[box]], %[[c100]]) : (!fir.box>, index) -> () + // CHECK-BOX-NOT: fir.save_result +} + +// CHECK-LABEL: func @call_derivedfunc() { +// CHECK-BOX-LABEL: func @call_derivedfunc() { +func @call_derivedfunc() { + %buffer = fir.alloca !fir.type + %cst = constant 4.200000e+01 : f32 + %res = fir.call @derivedfunc_callee(%cst) : (f32) -> !fir.type + fir.save_result %res to %buffer : !fir.type, !fir.ref> + return + // CHECK: %[[buffer:.*]] = fir.alloca !fir.type + // CHECK: %[[cst:.*]] = constant {{.*}} : f32 + // CHECK: fir.call @derivedfunc_callee(%[[buffer]], %[[cst]]) : (!fir.ref>, f32) -> () + // CHECK-NOT: fir.save_result + + // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.type + // CHECK-BOX: %[[cst:.*]] = constant {{.*}} : f32 + // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]] : (!fir.ref>) -> !fir.box> + // CHECK-BOX: fir.call @derivedfunc_callee(%[[box]], %[[cst]]) : (!fir.box>, f32) -> () + // CHECK-BOX-NOT: fir.save_result +} + +func private @derived_lparams_func() -> !fir.type + +// CHECK-LABEL: func @call_derived_lparams_func( +// CHECK-SAME: %[[buffer:.*]]: !fir.ref> +// CHECK-BOX-LABEL: func @call_derived_lparams_func( +// CHECK-BOX-SAME: %[[buffer:.*]]: !fir.ref> +func @call_derived_lparams_func(%buffer: !fir.ref>) { + %l1 = constant 3 : i32 + %l2 = constant 5 : i32 + %res = fir.call @derived_lparams_func() : () -> !fir.type + fir.save_result %res to %buffer typeparams %l1, %l2 : !fir.type, !fir.ref>, i32, i32 + return + + // CHECK: %[[l1:.*]] = constant 3 : i32 + // CHECK: %[[l2:.*]] = constant 5 : i32 + // CHECK: fir.call @derived_lparams_func(%[[buffer]]) : (!fir.ref>) -> () + // CHECK-NOT: fir.save_result + + // CHECK-BOX: %[[l1:.*]] = constant 3 : i32 + // CHECK-BOX: %[[l2:.*]] = constant 5 : i32 + // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]] typeparams %[[l1]], %[[l2]] : (!fir.ref>, i32, i32) -> !fir.box> + // CHECK-BOX: fir.call @derived_lparams_func(%[[box]]) : (!fir.box>) -> () + // CHECK-BOX-NOT: fir.save_result +} + +// CHECK-LABEL: func @call_boxfunc() { +// CHECK-BOX-LABEL: func @call_boxfunc() { +func @call_boxfunc() { + %buffer = fir.alloca !fir.box> + %res = fir.call @boxfunc_callee() : () -> !fir.box> + fir.save_result %res to %buffer: !fir.box>, !fir.ref>> + return + + // CHECK: %[[buffer:.*]] = fir.alloca !fir.box> + // CHECK: fir.call @boxfunc_callee(%[[buffer]]) : (!fir.ref>>) -> () + // CHECK-NOT: fir.save_result + + // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.box> + // CHECK-BOX: fir.call @boxfunc_callee(%[[buffer]]) : (!fir.ref>>) -> () + // CHECK-BOX-NOT: fir.save_result +} + +func private @chararrayfunc(index, index) -> !fir.array> + +// CHECK-LABEL: func @call_chararrayfunc() { +// CHECK-BOX-LABEL: func @call_chararrayfunc() { +func @call_chararrayfunc() { + %c100 = constant 100 : index + %c50 = constant 50 : index + %buffer = fir.alloca !fir.array>(%c100 : index), %c50 + %shape = fir.shape %c100 : (index) -> !fir.shape<1> + %res = fir.call @chararrayfunc(%c100, %c50) : (index, index) -> !fir.array> + fir.save_result %res to %buffer(%shape) typeparams %c50 : !fir.array>, !fir.ref>>, !fir.shape<1>, index + return + + // CHECK: %[[c100:.*]] = constant 100 : index + // CHECK: %[[c50:.*]] = constant 50 : index + // CHECK: %[[buffer:.*]] = fir.alloca !fir.array>(%[[c100]] : index), %[[c50]] + // CHECK: fir.call @chararrayfunc(%[[buffer]], %[[c100]], %[[c50]]) : (!fir.ref>>, index, index) -> () + // CHECK-NOT: fir.save_result + + // CHECK-BOX: %[[c100:.*]] = constant 100 : index + // CHECK-BOX: %[[c50:.*]] = constant 50 : index + // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.array>(%[[c100]] : index), %[[c50]] + // CHECK-BOX: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1> + // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]](%[[shape]]) typeparams %[[c50]] : (!fir.ref>>, !fir.shape<1>, index) -> !fir.box>> + // CHECK-BOX: fir.call @chararrayfunc(%[[box]], %[[c100]], %[[c50]]) : (!fir.box>>, index, index) -> () + // CHECK-BOX-NOT: fir.save_result +} + +// ------------------------ Test fir.address_of rewrite ------------------------ + +func private @takesfuncarray((i32) -> !fir.array) + +// CHECK-LABEL: func @test_address_of() { +// CHECK-BOX-LABEL: func @test_address_of() { +func @test_address_of() { + %0 = fir.address_of(@arrayfunc) : (i32) -> !fir.array + fir.call @takesfuncarray(%0) : ((i32) -> !fir.array) -> () + return + + // CHECK: %[[addrOf:.*]] = fir.address_of(@arrayfunc) : (!fir.ref>, i32) -> () + // CHECK: %[[conv:.*]] = fir.convert %[[addrOf]] : ((!fir.ref>, i32) -> ()) -> ((i32) -> !fir.array) + // CHECK: fir.call @takesfuncarray(%[[conv]]) : ((i32) -> !fir.array) -> () + + // CHECK-BOX: %[[addrOf:.*]] = fir.address_of(@arrayfunc) : (!fir.box>, i32) -> () + // CHECK-BOX: %[[conv:.*]] = fir.convert %[[addrOf]] : ((!fir.box>, i32) -> ()) -> ((i32) -> !fir.array) + // CHECK-BOX: fir.call @takesfuncarray(%[[conv]]) : ((i32) -> !fir.array) -> () + +} + +// ----------------------- Test indirect calls rewrite ------------------------ + +// CHECK-LABEL: func @test_indirect_calls( +// CHECK-SAME: %[[arg0:.*]]: () -> ()) { +// CHECK-BOX-LABEL: func @test_indirect_calls( +// CHECK-BOX-SAME: %[[arg0:.*]]: () -> ()) { +func @test_indirect_calls(%arg0: () -> ()) { + %c100 = constant 100 : index + %buffer = fir.alloca !fir.array, %c100 + %shape = fir.shape %c100 : (index) -> !fir.shape<1> + %0 = fir.convert %arg0 : (() -> ()) -> ((index) -> !fir.array) + %res = fir.call %0(%c100) : (index) -> !fir.array + fir.save_result %res to %buffer(%shape) : !fir.array, !fir.ref>, !fir.shape<1> + return + + // CHECK: %[[c100:.*]] = constant 100 : index + // CHECK: %[[buffer:.*]] = fir.alloca !fir.array, %[[c100]] + // CHECK: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1> + // CHECK: %[[original_conv:.*]] = fir.convert %[[arg0]] : (() -> ()) -> ((index) -> !fir.array) + // CHECK: %[[conv:.*]] = fir.convert %[[original_conv]] : ((index) -> !fir.array) -> ((!fir.ref>, index) -> ()) + // CHECK: fir.call %[[conv]](%[[buffer]], %c100) : (!fir.ref>, index) -> () + // CHECK-NOT: fir.save_result + + // CHECK-BOX: %[[c100:.*]] = constant 100 : index + // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.array, %[[c100]] + // CHECK-BOX: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1> + // CHECK-BOX: %[[original_conv:.*]] = fir.convert %[[arg0]] : (() -> ()) -> ((index) -> !fir.array) + // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]](%[[shape]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> + // CHECK-BOX: %[[conv:.*]] = fir.convert %[[original_conv]] : ((index) -> !fir.array) -> ((!fir.box>, index) -> ()) + // CHECK-BOX: fir.call %[[conv]](%[[box]], %c100) : (!fir.box>, index) -> () + // CHECK-BOX-NOT: fir.save_result +}