[MLIR] Migrate Arithmetic -> LLVM conversion pass to the auto-generated constructor

See #57475

Differential Revision: https://reviews.llvm.org/D134752
This commit is contained in:
Michele Scuttari 2022-09-27 19:41:08 +02:00
parent c29d911fd3
commit b1ce63bb5f
No known key found for this signature in database
GPG Key ID: E79E7BDFEE4B62D4
5 changed files with 10 additions and 17 deletions

View File

@ -17,14 +17,12 @@ class LLVMTypeConverter;
class RewritePatternSet; class RewritePatternSet;
class Pass; class Pass;
#define GEN_PASS_DECL_CONVERTARITHMETICTOLLVM #define GEN_PASS_DECL_ARITHMETICTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc" #include "mlir/Conversion/Passes.h.inc"
namespace arith { namespace arith {
void populateArithmeticToLLVMConversionPatterns(LLVMTypeConverter &converter, void populateArithmeticToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns); RewritePatternSet &patterns);
std::unique_ptr<Pass> createConvertArithmeticToLLVMPass();
} // namespace arith } // namespace arith
} // namespace mlir } // namespace mlir

View File

@ -96,12 +96,11 @@ def ConvertAMDGPUToROCDL : Pass<"convert-amdgpu-to-rocdl"> {
// ArithmeticToLLVM // ArithmeticToLLVM
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def ConvertArithmeticToLLVM : Pass<"convert-arith-to-llvm"> { def ArithmeticToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
let summary = "Convert Arithmetic dialect to LLVM dialect"; let summary = "Convert Arithmetic dialect to LLVM dialect";
let description = [{ let description = [{
This pass converts supported Arithmetic ops to LLVM dialect instructions. This pass converts supported Arithmetic ops to LLVM dialect instructions.
}]; }];
let constructor = "mlir::arith::createConvertArithmeticToLLVMPass()";
let dependentDialects = ["LLVM::LLVMDialect"]; let dependentDialects = ["LLVM::LLVMDialect"];
let options = [ let options = [
Option<"indexBitwidth", "index-bitwidth", "unsigned", Option<"indexBitwidth", "index-bitwidth", "unsigned",

View File

@ -16,7 +16,7 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
#define GEN_PASS_DEF_CONVERTARITHMETICTOLLVM #define GEN_PASS_DEF_ARITHMETICTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc" #include "mlir/Conversion/Passes.h.inc"
} // namespace mlir } // namespace mlir
@ -320,9 +320,10 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
namespace { namespace {
struct ConvertArithmeticToLLVMPass struct ArithmeticToLLVMConversionPass
: public impl::ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> { : public impl::ArithmeticToLLVMConversionPassBase<
ConvertArithmeticToLLVMPass() = default; ArithmeticToLLVMConversionPass> {
using Base::Base;
void runOnOperation() override { void runOnOperation() override {
LLVMConversionTarget target(getContext()); LLVMConversionTarget target(getContext());
@ -395,7 +396,3 @@ void mlir::arith::populateArithmeticToLLVMConversionPatterns(
>(converter); >(converter);
// clang-format on // clang-format on
} }
std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() {
return std::make_unique<ConvertArithmeticToLLVMPass>();
}

View File

@ -35,8 +35,8 @@ void lowerModuleToLLVM(MlirContext ctx, MlirModule module) {
MlirOpPassManager opm = mlirPassManagerGetNestedUnder( MlirOpPassManager opm = mlirPassManagerGetNestedUnder(
pm, mlirStringRefCreateFromCString("func.func")); pm, mlirStringRefCreateFromCString("func.func"));
mlirPassManagerAddOwnedPass(pm, mlirCreateConversionConvertFuncToLLVM()); mlirPassManagerAddOwnedPass(pm, mlirCreateConversionConvertFuncToLLVM());
mlirOpPassManagerAddOwnedPass(opm, mlirOpPassManagerAddOwnedPass(
mlirCreateConversionConvertArithmeticToLLVM()); opm, mlirCreateConversionArithmeticToLLVMConversionPass());
MlirLogicalResult status = mlirPassManagerRun(pm, module); MlirLogicalResult status = mlirPassManagerRun(pm, module);
if (mlirLogicalResultIsFailure(status)) { if (mlirLogicalResultIsFailure(status)) {
fprintf(stderr, "Unexpected failure running pass pipeline\n"); fprintf(stderr, "Unexpected failure running pass pipeline\n");

View File

@ -54,8 +54,7 @@ static struct LLVMInitializer {
static LogicalResult lowerToLLVMDialect(ModuleOp module) { static LogicalResult lowerToLLVMDialect(ModuleOp module) {
PassManager pm(module.getContext()); PassManager pm(module.getContext());
pm.addPass(mlir::createMemRefToLLVMConversionPass()); pm.addPass(mlir::createMemRefToLLVMConversionPass());
pm.addNestedPass<func::FuncOp>( pm.addNestedPass<func::FuncOp>(mlir::createArithmeticToLLVMConversionPass());
mlir::arith::createConvertArithmeticToLLVMPass());
pm.addPass(mlir::createConvertFuncToLLVMPass()); pm.addPass(mlir::createConvertFuncToLLVMPass());
pm.addPass(mlir::createReconcileUnrealizedCastsPass()); pm.addPass(mlir::createReconcileUnrealizedCastsPass());
return pm.run(module); return pm.run(module);