[FLANG][NFCI]De-duplicate code in SimplifyIntrinsics

This removes a bunch of duplicated code, by adding an intermediate
function simplifyReduction that takes a std::function argument
for the actual replacement of the code.

No functional change intended.

Reviewed By: vzakhari

Differential Revision: https://reviews.llvm.org/D132588
This commit is contained in:
Mats Petersson 2022-08-19 17:45:35 +01:00
parent 10dfcf1f87
commit 43159b5808
1 changed files with 58 additions and 67 deletions

View File

@ -52,9 +52,11 @@ namespace {
class SimplifyIntrinsicsPass
: public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
using FunctionTypeGeneratorTy =
std::function<mlir::FunctionType(fir::FirOpBuilder &)>;
llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
using FunctionBodyGeneratorTy =
std::function<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
using GenReductionBodyTy = llvm::function_ref<void(
fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp)>;
public:
/// Generate a new function implementing a simplified version
@ -68,6 +70,16 @@ public:
FunctionBodyGeneratorTy bodyGenerator);
void runOnOperation() override;
void getDependentDialects(mlir::DialectRegistry &registry) const override;
private:
/// Helper function to replace a reduction type of call with its
/// simplified form. The actual function is generated using a callback
/// function.
/// \p call is the call to be replaced
/// \p kindMap is used to create FIROpBuilder
/// \p genBodyFunc is the callback that builds the replacement function
void simplifyReduction(fir::CallOp call, const fir::KindMapping &kindMap,
GenReductionBodyTy genBodyFunc);
};
} // namespace
@ -81,10 +93,10 @@ static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
{elementType});
}
using BodyOpGeneratorTy =
std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location,
const mlir::Type &, mlir::Value, mlir::Value)>;
using InitValGeneratorTy = std::function<mlir::Value(
using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
mlir::Value)>;
using InitValGeneratorTy = llvm::function_ref<mlir::Value(
fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
/// Generate the reduction loop into \p funcOp.
@ -432,6 +444,43 @@ static llvm::Optional<mlir::Type> getArgElementType(mlir::Value val) {
} while (true);
}
void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
const fir::KindMapping &kindMap,
GenReductionBodyTy genBodyFunc) {
mlir::SymbolRefAttr callee = call.getCalleeAttr();
mlir::StringRef funcName = callee.getLeafReference().getValue();
mlir::Operation::operand_range args = call.getArgs();
// args[1] and args[2] are source filename and line number, ignored.
const mlir::Value &dim = args[3];
const mlir::Value &mask = args[4];
// dim is zero when it is absent, which is an implementation
// detail in the runtime library.
bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
unsigned rank = getDimCount(args[0]);
if (dimAndMaskAbsent && rank == 1) {
mlir::Location loc = call.getLoc();
mlir::Type type;
fir::FirOpBuilder builder(call, kindMap);
if (funcName.endswith("Integer4")) {
type = mlir::IntegerType::get(builder.getContext(), 32);
} else if (funcName.endswith("Real8")) {
type = mlir::FloatType::getF64(builder.getContext());
} else {
return;
}
auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
return genNoneBoxType(builder, type);
};
mlir::func::FuncOp newFunc =
getOrCreateFunction(builder, funcName, typeGenerator, genBodyFunc);
auto newCall =
builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
call->replaceAllUsesWith(newCall.getResults());
call->dropAllReferences();
call->erase();
}
}
void SimplifyIntrinsicsPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
mlir::ModuleOp module = getOperation();
@ -450,37 +499,7 @@ void SimplifyIntrinsicsPass::runOnOperation() {
// int dim, const Descriptor *mask)
//
if (funcName.startswith("_FortranASum")) {
mlir::Operation::operand_range args = call.getArgs();
// args[1] and args[2] are source filename and line number, ignored.
const mlir::Value &dim = args[3];
const mlir::Value &mask = args[4];
// dim is zero when it is absent, which is an implementation
// detail in the runtime library.
bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
unsigned rank = getDimCount(args[0]);
if (dimAndMaskAbsent && rank == 1) {
mlir::Location loc = call.getLoc();
mlir::Type type;
fir::FirOpBuilder builder(op, kindMap);
if (funcName.endswith("Integer4")) {
type = mlir::IntegerType::get(builder.getContext(), 32);
} else if (funcName.endswith("Real8")) {
type = mlir::FloatType::getF64(builder.getContext());
} else {
return;
}
auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
return genNoneBoxType(builder, type);
};
mlir::func::FuncOp newFunc = getOrCreateFunction(
builder, funcName, typeGenerator, genFortranASumBody);
auto newCall = builder.create<fir::CallOp>(
loc, newFunc, mlir::ValueRange{args[0]});
call->replaceAllUsesWith(newCall.getResults());
call->dropAllReferences();
call->erase();
}
simplifyReduction(call, kindMap, genFortranASumBody);
return;
}
if (funcName.startswith("_FortranADotProduct")) {
@ -544,37 +563,9 @@ void SimplifyIntrinsicsPass::runOnOperation() {
return;
}
if (funcName.startswith("_FortranAMaxval")) {
mlir::Operation::operand_range args = call.getArgs();
const mlir::Value &dim = args[3];
const mlir::Value &mask = args[4];
// dim is zero when it is absent, which is an implementation
// detail in the runtime library.
bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
unsigned rank = getDimCount(args[0]);
if (dimAndMaskAbsent && rank == 1) {
mlir::Location loc = call.getLoc();
mlir::Type type;
fir::FirOpBuilder builder(op, kindMap);
if (funcName.endswith("Integer4")) {
type = mlir::IntegerType::get(builder.getContext(), 32);
} else if (funcName.endswith("Real8")) {
type = mlir::FloatType::getF64(builder.getContext());
} else {
simplifyReduction(call, kindMap, genFortranAMaxvalBody);
return;
}
auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
return genNoneBoxType(builder, type);
};
mlir::func::FuncOp newFunc = getOrCreateFunction(
builder, funcName, typeGenerator, genFortranAMaxvalBody);
auto newCall = builder.create<fir::CallOp>(
loc, newFunc, mlir::ValueRange{args[0]});
call->replaceAllUsesWith(newCall.getResults());
call->dropAllReferences();
call->erase();
return;
}
}
}
}
});