forked from OSchip/llvm-project
[FLANG][NFCI]De-duplicate code in SimplifyIntrinsics
This removes a bunch of duplicated code, by adding an intermediate function simplifyReduction that takes a std::function argument for the actual replacement of the code. No functional change intended. Reviewed By: vzakhari Differential Revision: https://reviews.llvm.org/D132588
This commit is contained in:
parent
10dfcf1f87
commit
43159b5808
|
@ -52,9 +52,11 @@ namespace {
|
|||
class SimplifyIntrinsicsPass
|
||||
: public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
|
||||
using FunctionTypeGeneratorTy =
|
||||
std::function<mlir::FunctionType(fir::FirOpBuilder &)>;
|
||||
llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
|
||||
using FunctionBodyGeneratorTy =
|
||||
std::function<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
|
||||
llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
|
||||
using GenReductionBodyTy = llvm::function_ref<void(
|
||||
fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp)>;
|
||||
|
||||
public:
|
||||
/// Generate a new function implementing a simplified version
|
||||
|
@ -68,6 +70,16 @@ public:
|
|||
FunctionBodyGeneratorTy bodyGenerator);
|
||||
void runOnOperation() override;
|
||||
void getDependentDialects(mlir::DialectRegistry ®istry) const override;
|
||||
|
||||
private:
|
||||
/// Helper function to replace a reduction type of call with its
|
||||
/// simplified form. The actual function is generated using a callback
|
||||
/// function.
|
||||
/// \p call is the call to be replaced
|
||||
/// \p kindMap is used to create FIROpBuilder
|
||||
/// \p genBodyFunc is the callback that builds the replacement function
|
||||
void simplifyReduction(fir::CallOp call, const fir::KindMapping &kindMap,
|
||||
GenReductionBodyTy genBodyFunc);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
@ -81,10 +93,10 @@ static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
|
|||
{elementType});
|
||||
}
|
||||
|
||||
using BodyOpGeneratorTy =
|
||||
std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location,
|
||||
const mlir::Type &, mlir::Value, mlir::Value)>;
|
||||
using InitValGeneratorTy = std::function<mlir::Value(
|
||||
using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
|
||||
fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
|
||||
mlir::Value)>;
|
||||
using InitValGeneratorTy = llvm::function_ref<mlir::Value(
|
||||
fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
|
||||
|
||||
/// Generate the reduction loop into \p funcOp.
|
||||
|
@ -432,6 +444,43 @@ static llvm::Optional<mlir::Type> getArgElementType(mlir::Value val) {
|
|||
} while (true);
|
||||
}
|
||||
|
||||
void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
|
||||
const fir::KindMapping &kindMap,
|
||||
GenReductionBodyTy genBodyFunc) {
|
||||
mlir::SymbolRefAttr callee = call.getCalleeAttr();
|
||||
mlir::StringRef funcName = callee.getLeafReference().getValue();
|
||||
mlir::Operation::operand_range args = call.getArgs();
|
||||
// args[1] and args[2] are source filename and line number, ignored.
|
||||
const mlir::Value &dim = args[3];
|
||||
const mlir::Value &mask = args[4];
|
||||
// dim is zero when it is absent, which is an implementation
|
||||
// detail in the runtime library.
|
||||
bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
|
||||
unsigned rank = getDimCount(args[0]);
|
||||
if (dimAndMaskAbsent && rank == 1) {
|
||||
mlir::Location loc = call.getLoc();
|
||||
mlir::Type type;
|
||||
fir::FirOpBuilder builder(call, kindMap);
|
||||
if (funcName.endswith("Integer4")) {
|
||||
type = mlir::IntegerType::get(builder.getContext(), 32);
|
||||
} else if (funcName.endswith("Real8")) {
|
||||
type = mlir::FloatType::getF64(builder.getContext());
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
|
||||
return genNoneBoxType(builder, type);
|
||||
};
|
||||
mlir::func::FuncOp newFunc =
|
||||
getOrCreateFunction(builder, funcName, typeGenerator, genBodyFunc);
|
||||
auto newCall =
|
||||
builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
|
||||
call->replaceAllUsesWith(newCall.getResults());
|
||||
call->dropAllReferences();
|
||||
call->erase();
|
||||
}
|
||||
}
|
||||
|
||||
void SimplifyIntrinsicsPass::runOnOperation() {
|
||||
LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
|
||||
mlir::ModuleOp module = getOperation();
|
||||
|
@ -450,37 +499,7 @@ void SimplifyIntrinsicsPass::runOnOperation() {
|
|||
// int dim, const Descriptor *mask)
|
||||
//
|
||||
if (funcName.startswith("_FortranASum")) {
|
||||
mlir::Operation::operand_range args = call.getArgs();
|
||||
// args[1] and args[2] are source filename and line number, ignored.
|
||||
const mlir::Value &dim = args[3];
|
||||
const mlir::Value &mask = args[4];
|
||||
// dim is zero when it is absent, which is an implementation
|
||||
// detail in the runtime library.
|
||||
bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
|
||||
unsigned rank = getDimCount(args[0]);
|
||||
if (dimAndMaskAbsent && rank == 1) {
|
||||
mlir::Location loc = call.getLoc();
|
||||
mlir::Type type;
|
||||
fir::FirOpBuilder builder(op, kindMap);
|
||||
if (funcName.endswith("Integer4")) {
|
||||
type = mlir::IntegerType::get(builder.getContext(), 32);
|
||||
} else if (funcName.endswith("Real8")) {
|
||||
type = mlir::FloatType::getF64(builder.getContext());
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
|
||||
return genNoneBoxType(builder, type);
|
||||
};
|
||||
mlir::func::FuncOp newFunc = getOrCreateFunction(
|
||||
builder, funcName, typeGenerator, genFortranASumBody);
|
||||
auto newCall = builder.create<fir::CallOp>(
|
||||
loc, newFunc, mlir::ValueRange{args[0]});
|
||||
call->replaceAllUsesWith(newCall.getResults());
|
||||
call->dropAllReferences();
|
||||
call->erase();
|
||||
}
|
||||
|
||||
simplifyReduction(call, kindMap, genFortranASumBody);
|
||||
return;
|
||||
}
|
||||
if (funcName.startswith("_FortranADotProduct")) {
|
||||
|
@ -544,37 +563,9 @@ void SimplifyIntrinsicsPass::runOnOperation() {
|
|||
return;
|
||||
}
|
||||
if (funcName.startswith("_FortranAMaxval")) {
|
||||
mlir::Operation::operand_range args = call.getArgs();
|
||||
const mlir::Value &dim = args[3];
|
||||
const mlir::Value &mask = args[4];
|
||||
// dim is zero when it is absent, which is an implementation
|
||||
// detail in the runtime library.
|
||||
bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
|
||||
unsigned rank = getDimCount(args[0]);
|
||||
if (dimAndMaskAbsent && rank == 1) {
|
||||
mlir::Location loc = call.getLoc();
|
||||
mlir::Type type;
|
||||
fir::FirOpBuilder builder(op, kindMap);
|
||||
if (funcName.endswith("Integer4")) {
|
||||
type = mlir::IntegerType::get(builder.getContext(), 32);
|
||||
} else if (funcName.endswith("Real8")) {
|
||||
type = mlir::FloatType::getF64(builder.getContext());
|
||||
} else {
|
||||
simplifyReduction(call, kindMap, genFortranAMaxvalBody);
|
||||
return;
|
||||
}
|
||||
auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
|
||||
return genNoneBoxType(builder, type);
|
||||
};
|
||||
mlir::func::FuncOp newFunc = getOrCreateFunction(
|
||||
builder, funcName, typeGenerator, genFortranAMaxvalBody);
|
||||
auto newCall = builder.create<fir::CallOp>(
|
||||
loc, newFunc, mlir::ValueRange{args[0]});
|
||||
call->replaceAllUsesWith(newCall.getResults());
|
||||
call->dropAllReferences();
|
||||
call->erase();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue