diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h index 480401410dbf..5501158eaabb 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h @@ -93,8 +93,8 @@ inline void cleanupAndPrintFunction(mlir::FuncOp f) { } }; auto pm = cleanupPassManager(); - check(mlir::verify(f.getModule())); - check(pm->run(f.getModule())); + check(mlir::verify(f.getParentOfType())); + check(pm->run(f.getParentOfType())); if (printToOuts) f.print(llvm::outs()); } diff --git a/mlir/examples/Linalg/Linalg3/Example.cpp b/mlir/examples/Linalg/Linalg3/Example.cpp index e9b4e68e4f3b..e596ddabcc99 100644 --- a/mlir/examples/Linalg/Linalg3/Example.cpp +++ b/mlir/examples/Linalg/Linalg3/Example.cpp @@ -173,7 +173,7 @@ TEST_FUNC(matmul_as_matvec_as_affine) { lowerToLoops(f); PassManager pm; pm.addPass(createLowerLinalgLoadStorePass()); - if (succeeded(pm.run(f.getModule()))) + if (succeeded(pm.run(f.getParentOfType()))) cleanupAndPrintFunction(f); // clang-format off diff --git a/mlir/examples/Linalg/Linalg4/Example.cpp b/mlir/examples/Linalg/Linalg4/Example.cpp index 4455d21e2746..ed439f2f3478 100644 --- a/mlir/examples/Linalg/Linalg4/Example.cpp +++ b/mlir/examples/Linalg/Linalg4/Example.cpp @@ -69,7 +69,7 @@ TEST_FUNC(matmul_tiled_loops) { lowerToTiledLoops(f, {8, 9}); PassManager pm; pm.addPass(createLowerLinalgLoadStorePass()); - if (succeeded(pm.run(f.getModule()))) + if (succeeded(pm.run(f.getParentOfType()))) cleanupAndPrintFunction(f); // clang-format off diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index ae13328bf7f6..c3e93a5b5fa8 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -27,8 +27,6 @@ #include "llvm/ADT/SmallString.h" namespace mlir { -class ModuleOp; - //===--------------------------------------------------------------------===// // Function Operation. //===--------------------------------------------------------------------===// @@ -60,9 +58,6 @@ public: FunctionType type, ArrayRef attrs, ArrayRef argAttrs); - /// Get the parent module. - ModuleOp getModule(); - /// Operation hooks. static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 5799857936ad..0b876c5dfb1d 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -441,14 +441,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern { createIndexConstant(rewriter, op->getLoc(), elementSize)}); // Insert the `malloc` declaration if it is not already present. - FuncOp mallocFunc = - op->getParentOfType().getModule().getNamedFunction("malloc"); + auto module = op->getParentOfType(); + FuncOp mallocFunc = module.getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(getIndexType(), getVoidPtrType()); mallocFunc = FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType); - op->getParentOfType().getModule().push_back(mallocFunc); + module.push_back(mallocFunc); } // Allocate the underlying buffer and store a pointer to it in the MemRef @@ -503,12 +503,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { OperandAdaptor transformed(operands); // Insert the `free` declaration if it is not already present. - FuncOp freeFunc = - op->getParentOfType().getModule().getNamedFunction("free"); + FuncOp freeFunc = op->getParentOfType().getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(getVoidPtrType(), {}); freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType); - op->getParentOfType().getModule().push_back(freeFunc); + op->getParentOfType().push_back(freeFunc); } auto type = transformed.memref()->getType().cast(); diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index b471010d0af9..d34f4b4044e1 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -73,12 +73,6 @@ void FuncOp::build(Builder *builder, OperationState *result, StringRef name, result->addAttribute(getArgAttrName(i, argAttrName), argDict); } -/// Get the parent module. -ModuleOp FuncOp::getModule() { - auto *parent = getOperation()->getContainingRegion(); - return parent ? parent->getParentOfType() : nullptr; -} - /// Parsing/Printing methods. static ParseResult parseArgumentList(OpAsmParser *parser, SmallVectorImpl &argTypes, diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 0cda24722e29..d43c1cf46c87 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -575,7 +575,7 @@ public: // types and returns pointers to the output types. static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) { auto implFnName = (libFn.getName().str() + "_impl"); - auto module = libFn.getModule(); + auto module = libFn.getParentOfType(); if (auto f = module.getNamedFunction(implFnName)) { return f; } diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index 8ba2169b0210..4d07d911ab13 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -66,7 +66,7 @@ static void printIR(const llvm::Any &ir, bool printModuleScope, // Print the function name and a newline before the Module. out << " (function: " << function.getName() << ")\n"; - function.getModule().print(out); + function.getParentOfType().print(out); return; } diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index e8c92e56d09f..125fcd37cbb7 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -340,7 +340,7 @@ PassInstrumentor *FunctionAnalysisManager::getPassInstrumentor() const { /// Create an analysis slice for the given child function. FunctionAnalysisManager ModuleAnalysisManager::slice(FuncOp func) { - assert(func.getModule() == moduleAnalyses.getIRUnit() && + assert(func.getOperation()->getParentOp() == moduleAnalyses.getIRUnit() && "function has a different parent module"); auto it = functionAnalyses.find(func); if (it == functionAnalyses.end()) {