llvm-project/flang/lib/Optimizer/Transforms/AbstractResult.cpp

290 lines
11 KiB
C++

//===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"
#define DEBUG_TYPE "flang-abstract-result-opt"
namespace fir {
namespace {
struct AbstractResultOptions {
// Always pass result as a fir.box argument.
bool boxResult = false;
// New function block argument for the result if the current FuncOp had
// an abstract result.
mlir::Value newArg;
};
static bool mustConvertCallOrFunc(mlir::FunctionType type) {
if (type.getNumResults() == 0)
return false;
auto resultType = type.getResult(0);
return resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>();
}
static mlir::Type getResultArgumentType(mlir::Type resultType,
const AbstractResultOptions &options) {
return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
.Case<fir::SequenceType, fir::RecordType>(
[&](mlir::Type type) -> mlir::Type {
if (options.boxResult)
return fir::BoxType::get(type);
return fir::ReferenceType::get(type);
})
.Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
return fir::ReferenceType::get(type);
})
.Default([](mlir::Type) -> mlir::Type {
llvm_unreachable("bad abstract result type");
});
}
static mlir::FunctionType
getNewFunctionType(mlir::FunctionType funcTy,
const AbstractResultOptions &options) {
auto resultType = funcTy.getResult(0);
auto argTy = getResultArgumentType(resultType, options);
llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
/*resultTypes=*/{});
}
static bool mustEmboxResult(mlir::Type resultType,
const AbstractResultOptions &options) {
return resultType.isa<fir::SequenceType, fir::RecordType>() &&
options.boxResult;
}
class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
public:
using OpRewritePattern::OpRewritePattern;
CallOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt)
: OpRewritePattern(context), options{opt} {}
mlir::LogicalResult
matchAndRewrite(fir::CallOp callOp,
mlir::PatternRewriter &rewriter) const override {
auto loc = callOp.getLoc();
auto result = callOp->getResult(0);
if (!result.hasOneUse()) {
mlir::emitError(loc,
"calls with abstract result must have exactly one user");
return mlir::failure();
}
auto saveResult =
mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
if (!saveResult) {
mlir::emitError(
loc, "calls with abstract result must be used in fir.save_result");
return mlir::failure();
}
auto argType = getResultArgumentType(result.getType(), options);
auto buffer = saveResult.getMemref();
mlir::Value arg = buffer;
if (mustEmboxResult(result.getType(), options))
arg = rewriter.create<fir::EmboxOp>(
loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
saveResult.getTypeparams());
llvm::SmallVector<mlir::Type> newResultTypes;
if (callOp.getCallee()) {
llvm::SmallVector<mlir::Value> newOperands = {arg};
newOperands.append(callOp.getOperands().begin(),
callOp.getOperands().end());
rewriter.create<fir::CallOp>(loc, callOp.getCallee().getValue(),
newResultTypes, newOperands);
} else {
// Indirect calls.
llvm::SmallVector<mlir::Type> newInputTypes = {argType};
for (auto operand : callOp.getOperands().drop_front())
newInputTypes.push_back(operand.getType());
auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes,
newResultTypes);
llvm::SmallVector<mlir::Value> newOperands;
newOperands.push_back(
rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0)));
newOperands.push_back(arg);
newOperands.append(callOp.getOperands().begin() + 1,
callOp.getOperands().end());
rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes,
newOperands);
}
callOp->dropAllReferences();
rewriter.eraseOp(callOp);
return mlir::success();
}
private:
const AbstractResultOptions &options;
};
class SaveResultOpConversion
: public mlir::OpRewritePattern<fir::SaveResultOp> {
public:
using OpRewritePattern::OpRewritePattern;
SaveResultOpConversion(mlir::MLIRContext *context)
: OpRewritePattern(context) {}
mlir::LogicalResult
matchAndRewrite(fir::SaveResultOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return mlir::success();
}
};
class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
public:
using OpRewritePattern::OpRewritePattern;
ReturnOpConversion(mlir::MLIRContext *context,
const AbstractResultOptions &opt)
: OpRewritePattern(context), options{opt} {}
mlir::LogicalResult
matchAndRewrite(mlir::func::ReturnOp ret,
mlir::PatternRewriter &rewriter) const override {
rewriter.setInsertionPoint(ret);
auto returnedValue = ret.getOperand(0);
bool replacedStorage = false;
if (auto *op = returnedValue.getDefiningOp())
if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
auto resultStorage = load.getMemref();
load.getMemref().replaceAllUsesWith(options.newArg);
replacedStorage = true;
if (auto *alloc = resultStorage.getDefiningOp())
if (alloc->use_empty())
rewriter.eraseOp(alloc);
}
// The result storage may have been optimized out by a memory to
// register pass, this is possible for fir.box results, or fir.record
// with no length parameters. Simply store the result in the result storage.
// at the return point.
if (!replacedStorage)
rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue,
options.newArg);
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
return mlir::success();
}
private:
const AbstractResultOptions &options;
};
class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
public:
using OpRewritePattern::OpRewritePattern;
AddrOfOpConversion(mlir::MLIRContext *context,
const AbstractResultOptions &opt)
: OpRewritePattern(context), options{opt} {}
mlir::LogicalResult
matchAndRewrite(fir::AddrOfOp addrOf,
mlir::PatternRewriter &rewriter) const override {
auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
auto newFuncTy = getNewFunctionType(oldFuncTy, options);
auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
addrOf.getSymbol());
// Rather than converting all op a function pointer might transit through
// (e.g calls, stores, loads, converts...), cast new type to the abstract
// type. A conversion will be added when calling indirect calls of abstract
// types.
rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
return mlir::success();
}
private:
const AbstractResultOptions &options;
};
class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> {
public:
void runOnOperation() override {
auto *context = &getContext();
auto func = getOperation();
auto loc = func.getLoc();
mlir::RewritePatternSet patterns(context);
mlir::ConversionTarget target = *context;
AbstractResultOptions options{passResultAsBox.getValue(),
/*newArg=*/{}};
// Convert function type itself if it has an abstract result
auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
if (mustConvertCallOrFunc(funcTy)) {
func.setType(getNewFunctionType(funcTy, options));
unsigned zero = 0;
if (!func.empty()) {
// Insert new argument
mlir::OpBuilder rewriter(context);
auto resultType = funcTy.getResult(0);
auto argTy = getResultArgumentType(resultType, options);
options.newArg = func.front().insertArgument(zero, argTy, loc);
if (mustEmboxResult(resultType, options)) {
auto bufferType = fir::ReferenceType::get(resultType);
rewriter.setInsertionPointToStart(&func.front());
options.newArg =
rewriter.create<fir::BoxAddrOp>(loc, bufferType, options.newArg);
}
patterns.insert<ReturnOpConversion>(context, options);
target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
[](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
}
}
if (func.empty())
return;
// Convert the calls and, if needed, the ReturnOp in the function body.
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithmeticDialect,
mlir::func::FuncDialect>();
target.addIllegalOp<fir::SaveResultOp>();
target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
return !mustConvertCallOrFunc(call.getFunctionType());
});
target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
return !mustConvertCallOrFunc(funTy);
return true;
});
target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
if (dispatch->getNumResults() != 1)
return true;
auto resultType = dispatch->getResult(0).getType();
if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
mlir::emitError(dispatch.getLoc(),
"TODO: dispatchOp with abstract results");
return false;
}
return true;
});
patterns.insert<CallOpConversion>(context, options);
patterns.insert<SaveResultOpConversion>(context);
patterns.insert<AddrOfOpConversion>(context, options);
if (mlir::failed(
mlir::applyPartialConversion(func, target, std::move(patterns)))) {
mlir::emitError(func.getLoc(), "error in converting abstract results\n");
signalPassFailure();
}
}
};
} // end anonymous namespace
} // namespace fir
std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() {
return std::make_unique<AbstractResultOpt>();
}