[flang][NFC] Drop `AbstractResultOptions` structure

`AbstractResultOptions` is obsolete structure because `newArg` is used
only in `ReturnOpConversion`.
This change removes this struct, making dependencies of conversions more
straight-forward.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D129485
This commit is contained in:
Daniil Dudkin 2022-07-19 17:22:39 +03:00
parent 9fb33d52b0
commit ea1cdb58cc
1 changed files with 30 additions and 45 deletions

View File

@ -24,20 +24,12 @@
namespace fir {
namespace {
struct AbstractResultOptions {
// Always pass result as a fir.box argument.
bool boxResult = false;
// New function block argument for the result if the current FuncOp had
// an abstract result.
mlir::Value newArg;
};
static mlir::Type getResultArgumentType(mlir::Type resultType,
const AbstractResultOptions &options) {
bool shouldBoxResult) {
return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
.Case<fir::SequenceType, fir::RecordType>(
[&](mlir::Type type) -> mlir::Type {
if (options.boxResult)
if (shouldBoxResult)
return fir::BoxType::get(type);
return fir::ReferenceType::get(type);
})
@ -49,28 +41,26 @@ static mlir::Type getResultArgumentType(mlir::Type resultType,
});
}
static mlir::FunctionType
getNewFunctionType(mlir::FunctionType funcTy,
const AbstractResultOptions &options) {
static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
bool shouldBoxResult) {
auto resultType = funcTy.getResult(0);
auto argTy = getResultArgumentType(resultType, options);
auto argTy = getResultArgumentType(resultType, shouldBoxResult);
llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
/*resultTypes=*/{});
}
static bool mustEmboxResult(mlir::Type resultType,
const AbstractResultOptions &options) {
static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
return resultType.isa<fir::SequenceType, fir::RecordType>() &&
options.boxResult;
shouldBoxResult;
}
class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
public:
using OpRewritePattern::OpRewritePattern;
CallOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt)
: OpRewritePattern(context), options{opt} {}
CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
: OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
mlir::LogicalResult
matchAndRewrite(fir::CallOp callOp,
mlir::PatternRewriter &rewriter) const override {
@ -88,10 +78,10 @@ public:
loc, "calls with abstract result must be used in fir.save_result");
return mlir::failure();
}
auto argType = getResultArgumentType(result.getType(), options);
auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
auto buffer = saveResult.getMemref();
mlir::Value arg = buffer;
if (mustEmboxResult(result.getType(), options))
if (mustEmboxResult(result.getType(), shouldBoxResult))
arg = rewriter.create<fir::EmboxOp>(
loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
saveResult.getTypeparams());
@ -126,7 +116,7 @@ public:
}
private:
const AbstractResultOptions &options;
bool shouldBoxResult;
};
class SaveResultOpConversion
@ -146,9 +136,8 @@ public:
class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
public:
using OpRewritePattern::OpRewritePattern;
ReturnOpConversion(mlir::MLIRContext *context,
const AbstractResultOptions &opt)
: OpRewritePattern(context), options{opt} {}
ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
: OpRewritePattern(context), newArg{newArg} {}
mlir::LogicalResult
matchAndRewrite(mlir::func::ReturnOp ret,
mlir::PatternRewriter &rewriter) const override {
@ -158,7 +147,7 @@ public:
if (auto *op = returnedValue.getDefiningOp())
if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
auto resultStorage = load.getMemref();
load.getMemref().replaceAllUsesWith(options.newArg);
load.getMemref().replaceAllUsesWith(newArg);
replacedStorage = true;
if (auto *alloc = resultStorage.getDefiningOp())
if (alloc->use_empty())
@ -169,27 +158,25 @@ public:
// with no length parameters. Simply store the result in the result storage.
// at the return point.
if (!replacedStorage)
rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue,
options.newArg);
rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue, newArg);
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
return mlir::success();
}
private:
const AbstractResultOptions &options;
mlir::Value newArg;
};
class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
public:
using OpRewritePattern::OpRewritePattern;
AddrOfOpConversion(mlir::MLIRContext *context,
const AbstractResultOptions &opt)
: OpRewritePattern(context), options{opt} {}
AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
: OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
mlir::LogicalResult
matchAndRewrite(fir::AddrOfOp addrOf,
mlir::PatternRewriter &rewriter) const override {
auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
auto newFuncTy = getNewFunctionType(oldFuncTy, options);
auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
addrOf.getSymbol());
// Rather than converting all op a function pointer might transit through
@ -201,7 +188,7 @@ public:
}
private:
const AbstractResultOptions &options;
bool shouldBoxResult;
};
class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> {
@ -212,27 +199,25 @@ public:
auto loc = func.getLoc();
mlir::RewritePatternSet patterns(context);
mlir::ConversionTarget target = *context;
AbstractResultOptions options{passResultAsBox.getValue(),
/*newArg=*/{}};
const bool shouldBoxResult = passResultAsBox.getValue();
// Convert function type itself if it has an abstract result
auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
if (hasAbstractResult(funcTy)) {
func.setType(getNewFunctionType(funcTy, options));
func.setType(getNewFunctionType(funcTy, shouldBoxResult));
unsigned zero = 0;
if (!func.empty()) {
// Insert new argument
mlir::OpBuilder rewriter(context);
auto resultType = funcTy.getResult(0);
auto argTy = getResultArgumentType(resultType, options);
options.newArg = func.front().insertArgument(zero, argTy, loc);
if (mustEmboxResult(resultType, options)) {
auto argTy = getResultArgumentType(resultType, shouldBoxResult);
mlir::Value newArg = func.front().insertArgument(zero, argTy, loc);
if (mustEmboxResult(resultType, shouldBoxResult)) {
auto bufferType = fir::ReferenceType::get(resultType);
rewriter.setInsertionPointToStart(&func.front());
options.newArg =
rewriter.create<fir::BoxAddrOp>(loc, bufferType, options.newArg);
newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
}
patterns.insert<ReturnOpConversion>(context, options);
patterns.insert<ReturnOpConversion>(context, newArg);
target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
[](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
}
@ -264,9 +249,9 @@ public:
return true;
});
patterns.insert<CallOpConversion>(context, options);
patterns.insert<CallOpConversion>(context, shouldBoxResult);
patterns.insert<SaveResultOpConversion>(context);
patterns.insert<AddrOfOpConversion>(context, options);
patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
if (mlir::failed(
mlir::applyPartialConversion(func, target, std::move(patterns)))) {
mlir::emitError(func.getLoc(), "error in converting abstract results\n");