forked from OSchip/llvm-project
NFC: Refactor Module to be value typed.
As with Functions, Module will soon become an operation, which are value-typed. This eases the transition from Module to ModuleOp. A new class, OwningModuleRef is provided to allow for owning a reference to a Module, and will auto-delete the held module on destruction. PiperOrigin-RevId: 256196193
This commit is contained in:
parent
b4a2dbc8b6
commit
206e55cc16
|
@ -146,8 +146,8 @@ struct PythonFunction {
|
|||
/// Trivial C++ wrappers make use of the EDSC C API.
|
||||
struct PythonMLIRModule {
|
||||
PythonMLIRModule()
|
||||
: mlirContext(), module(new mlir::Module(&mlirContext)),
|
||||
moduleManager(module.get()) {}
|
||||
: mlirContext(), module(mlir::Module::create(&mlirContext)),
|
||||
moduleManager(*module) {}
|
||||
|
||||
PythonType makeScalarType(const std::string &mlirElemType,
|
||||
unsigned bitwidth) {
|
||||
|
@ -197,12 +197,12 @@ struct PythonMLIRModule {
|
|||
manager.addPass(mlir::createCSEPass());
|
||||
manager.addPass(mlir::createLowerAffinePass());
|
||||
manager.addPass(mlir::createConvertToLLVMIRPass());
|
||||
if (failed(manager.run(module.get()))) {
|
||||
if (failed(manager.run(*module))) {
|
||||
llvm::errs() << "conversion to the LLVM IR dialect failed\n";
|
||||
return;
|
||||
}
|
||||
|
||||
auto created = mlir::ExecutionEngine::create(module.get());
|
||||
auto created = mlir::ExecutionEngine::create(*module);
|
||||
llvm::handleAllErrors(created.takeError(),
|
||||
[](const llvm::ErrorInfoBase &b) {
|
||||
b.log(llvm::errs());
|
||||
|
@ -235,7 +235,7 @@ struct PythonMLIRModule {
|
|||
private:
|
||||
mlir::MLIRContext mlirContext;
|
||||
// One single module in a python-exposed MLIRContext for now.
|
||||
std::unique_ptr<mlir::Module> module;
|
||||
mlir::OwningModuleRef module;
|
||||
mlir::ModuleManager moduleManager;
|
||||
std::unique_ptr<mlir::ExecutionEngine> engine;
|
||||
};
|
||||
|
|
|
@ -57,7 +57,7 @@ inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context,
|
|||
}
|
||||
|
||||
/// A basic function builder
|
||||
inline mlir::Function makeFunction(mlir::Module &module, llvm::StringRef name,
|
||||
inline mlir::Function makeFunction(mlir::Module module, llvm::StringRef name,
|
||||
llvm::ArrayRef<mlir::Type> types,
|
||||
llvm::ArrayRef<mlir::Type> resultTypes) {
|
||||
auto *context = module.getContext();
|
||||
|
@ -92,7 +92,7 @@ inline void cleanupAndPrintFunction(mlir::Function f) {
|
|||
}
|
||||
};
|
||||
auto pm = cleanupPassManager();
|
||||
check(f.getModule()->verify());
|
||||
check(f.getModule().verify());
|
||||
check(pm->run(f.getModule()));
|
||||
if (printToOuts)
|
||||
f.print(llvm::outs());
|
||||
|
|
|
@ -51,7 +51,7 @@ void populateLinalg1ToLLVMConversionPatterns(
|
|||
/// Convert the Linalg dialect types and RangeOp, ViewOp and SliceOp operations
|
||||
/// to the LLVM IR dialect types and operations in the given `module`. This is
|
||||
/// the main entry point to the conversion.
|
||||
void convertToLLVM(mlir::Module &module);
|
||||
void convertToLLVM(mlir::Module module);
|
||||
} // end namespace linalg
|
||||
|
||||
#endif // LINALG1_CONVERTTOLLVMDIALECT_H_
|
||||
|
|
|
@ -406,11 +406,11 @@ struct LinalgTypeConverter : public LLVMTypeConverter {
|
|||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void linalg::convertToLLVM(mlir::Module &module) {
|
||||
void linalg::convertToLLVM(mlir::Module module) {
|
||||
// Remove affine constructs if any by using an existing pass.
|
||||
PassManager pm;
|
||||
pm.addPass(createLowerAffinePass());
|
||||
auto rr = pm.run(&module);
|
||||
auto rr = pm.run(module);
|
||||
(void)rr;
|
||||
assert(succeeded(rr) && "affine loop lowering failed");
|
||||
|
||||
|
|
|
@ -34,10 +34,10 @@ using namespace linalg::intrinsics;
|
|||
|
||||
TEST_FUNC(linalg_ops) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
OwningModuleRef module = Module::create(&context);
|
||||
auto indexType = mlir::IndexType::get(&context);
|
||||
mlir::Function f =
|
||||
makeFunction(module, "linalg_ops", {indexType, indexType, indexType}, {});
|
||||
mlir::Function f = makeFunction(*module, "linalg_ops",
|
||||
{indexType, indexType, indexType}, {});
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
|
@ -73,9 +73,9 @@ TEST_FUNC(linalg_ops) {
|
|||
|
||||
TEST_FUNC(linalg_ops_folded_slices) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
OwningModuleRef module = Module::create(&context);
|
||||
auto indexType = mlir::IndexType::get(&context);
|
||||
mlir::Function f = makeFunction(module, "linalg_ops_folded_slices",
|
||||
mlir::Function f = makeFunction(*module, "linalg_ops_folded_slices",
|
||||
{indexType, indexType, indexType}, {});
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
|
|
|
@ -37,7 +37,7 @@ using namespace linalg;
|
|||
using namespace linalg::common;
|
||||
using namespace linalg::intrinsics;
|
||||
|
||||
Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
||||
Function makeFunctionWithAMatmulOp(Module module, StringRef name) {
|
||||
MLIRContext *context = module.getContext();
|
||||
auto dynamic2DMemRefType = floatMemRefType<2>(context);
|
||||
mlir::Function f = linalg::common::makeFunction(
|
||||
|
@ -66,11 +66,11 @@ Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
|||
|
||||
TEST_FUNC(foo) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
|
||||
OwningModuleRef module = Module::create(&context);
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_as_loops");
|
||||
lowerToLoops(f);
|
||||
|
||||
convertLinalg3ToLLVM(module);
|
||||
convertLinalg3ToLLVM(*module);
|
||||
|
||||
// clang-format off
|
||||
// CHECK: {{.*}} = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
|
||||
|
@ -104,7 +104,7 @@ TEST_FUNC(foo) {
|
|||
// CHECK-NEXT: {{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
|
||||
// CHECK-NEXT: llvm.store {{.*}}, {{.*}} : !llvm<"float*">
|
||||
// clang-format on
|
||||
module.print(llvm::outs());
|
||||
module->print(llvm::outs());
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
|
|
@ -34,7 +34,7 @@ using namespace linalg;
|
|||
using namespace linalg::common;
|
||||
using namespace linalg::intrinsics;
|
||||
|
||||
Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
||||
Function makeFunctionWithAMatmulOp(Module module, StringRef name) {
|
||||
MLIRContext *context = module.getContext();
|
||||
auto dynamic2DMemRefType = floatMemRefType<2>(context);
|
||||
mlir::Function f = linalg::common::makeFunction(
|
||||
|
@ -63,7 +63,7 @@ Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
|||
|
||||
TEST_FUNC(matmul_as_matvec) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
Module module = Module::create(&context);
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
|
||||
lowerToFinerGrainedTensorContraction(f);
|
||||
composeSliceOps(f);
|
||||
|
@ -81,7 +81,7 @@ TEST_FUNC(matmul_as_matvec) {
|
|||
|
||||
TEST_FUNC(matmul_as_dot) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
Module module = Module::create(&context);
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
|
||||
lowerToFinerGrainedTensorContraction(f);
|
||||
lowerToFinerGrainedTensorContraction(f);
|
||||
|
@ -102,7 +102,7 @@ TEST_FUNC(matmul_as_dot) {
|
|||
|
||||
TEST_FUNC(matmul_as_loops) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
Module module = Module::create(&context);
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
|
||||
lowerToLoops(f);
|
||||
composeSliceOps(f);
|
||||
|
@ -134,7 +134,7 @@ TEST_FUNC(matmul_as_loops) {
|
|||
|
||||
TEST_FUNC(matmul_as_matvec_as_loops) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
Module module = Module::create(&context);
|
||||
mlir::Function f =
|
||||
makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops");
|
||||
lowerToFinerGrainedTensorContraction(f);
|
||||
|
@ -165,7 +165,7 @@ TEST_FUNC(matmul_as_matvec_as_loops) {
|
|||
|
||||
TEST_FUNC(matmul_as_matvec_as_affine) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
Module module = Module::create(&context);
|
||||
mlir::Function f =
|
||||
makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_affine");
|
||||
lowerToFinerGrainedTensorContraction(f);
|
||||
|
|
|
@ -37,7 +37,7 @@ using namespace linalg;
|
|||
using namespace linalg::common;
|
||||
using namespace linalg::intrinsics;
|
||||
|
||||
Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
||||
Function makeFunctionWithAMatmulOp(Module module, StringRef name) {
|
||||
MLIRContext *context = module.getContext();
|
||||
auto dynamic2DMemRefType = floatMemRefType<2>(context);
|
||||
mlir::Function f = linalg::common::makeFunction(
|
||||
|
@ -109,14 +109,14 @@ TEST_FUNC(execution) {
|
|||
// linalg.matmul operation and lower it all the way down to the LLVM IR
|
||||
// dialect through partial conversions.
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
|
||||
OwningModuleRef module = Module::create(&context);
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_as_loops");
|
||||
lowerToLoops(f);
|
||||
convertLinalg3ToLLVM(module);
|
||||
convertLinalg3ToLLVM(*module);
|
||||
|
||||
// Create an MLIR execution engine. The execution engine eagerly JIT-compiles
|
||||
// the module.
|
||||
auto maybeEngine = mlir::ExecutionEngine::create(&module);
|
||||
auto maybeEngine = mlir::ExecutionEngine::create(*module);
|
||||
assert(maybeEngine && "failed to construct an execution engine");
|
||||
auto &engine = maybeEngine.get();
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ class Module;
|
|||
} // end namespace mlir
|
||||
|
||||
namespace linalg {
|
||||
void convertLinalg3ToLLVM(mlir::Module &module);
|
||||
void convertLinalg3ToLLVM(mlir::Module module);
|
||||
} // end namespace linalg
|
||||
|
||||
#endif // LINALG3_CONVERTTOLLVMDIALECT_H_
|
||||
|
|
|
@ -146,7 +146,7 @@ static void populateLinalg3ToLLVMConversionPatterns(
|
|||
context);
|
||||
}
|
||||
|
||||
void linalg::convertLinalg3ToLLVM(Module &module) {
|
||||
void linalg::convertLinalg3ToLLVM(Module module) {
|
||||
// Remove affine constructs.
|
||||
for (auto func : module) {
|
||||
auto rr = lowerAffineConstructs(func);
|
||||
|
|
|
@ -34,7 +34,7 @@ using namespace linalg;
|
|||
using namespace linalg::common;
|
||||
using namespace linalg::intrinsics;
|
||||
|
||||
Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
||||
Function makeFunctionWithAMatmulOp(Module module, StringRef name) {
|
||||
MLIRContext *context = module.getContext();
|
||||
auto dynamic2DMemRefType = floatMemRefType<2>(context);
|
||||
mlir::Function f = linalg::common::makeFunction(
|
||||
|
@ -64,8 +64,8 @@ Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
|||
|
||||
TEST_FUNC(matmul_tiled_loops) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops");
|
||||
OwningModuleRef module = Module::create(&context);
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_loops");
|
||||
lowerToTiledLoops(f, {8, 9});
|
||||
PassManager pm;
|
||||
pm.addPass(createLowerLinalgLoadStorePass());
|
||||
|
@ -95,8 +95,8 @@ TEST_FUNC(matmul_tiled_loops) {
|
|||
|
||||
TEST_FUNC(matmul_tiled_views) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views");
|
||||
OwningModuleRef module = Module::create(&context);
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_views");
|
||||
OpBuilder b(f.getBody());
|
||||
lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8),
|
||||
b.create<ConstantIndexOp>(f.getLoc(), 9)});
|
||||
|
@ -124,9 +124,9 @@ TEST_FUNC(matmul_tiled_views) {
|
|||
|
||||
TEST_FUNC(matmul_tiled_views_as_loops) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
OwningModuleRef module = Module::create(&context);
|
||||
mlir::Function f =
|
||||
makeFunctionWithAMatmulOp(module, "matmul_tiled_views_as_loops");
|
||||
makeFunctionWithAMatmulOp(*module, "matmul_tiled_views_as_loops");
|
||||
OpBuilder b(f.getBody());
|
||||
lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8),
|
||||
b.create<ConstantIndexOp>(f.getLoc(), 9)});
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
class Module;
|
||||
class OwningModuleRef;
|
||||
} // namespace mlir
|
||||
|
||||
namespace toy {
|
||||
|
@ -35,8 +35,7 @@ class ModuleAST;
|
|||
|
||||
/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
|
||||
/// or nullptr on failure.
|
||||
std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
|
||||
ModuleAST &moduleAST);
|
||||
mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
|
||||
} // namespace toy
|
||||
|
||||
#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
|
||||
|
|
|
@ -66,22 +66,22 @@ public:
|
|||
|
||||
/// Public API: convert the AST for a Toy module (source file) to an MLIR
|
||||
/// Module.
|
||||
std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
|
||||
mlir::Module mlirGen(ModuleAST &moduleAST) {
|
||||
// We create an empty MLIR module and codegen functions one at a time and
|
||||
// add them to the module.
|
||||
theModule = make_unique<mlir::Module>(&context);
|
||||
theModule = mlir::Module::create(&context);
|
||||
|
||||
for (FunctionAST &F : moduleAST) {
|
||||
auto func = mlirGen(F);
|
||||
if (!func)
|
||||
return nullptr;
|
||||
theModule->push_back(func);
|
||||
theModule.push_back(func);
|
||||
}
|
||||
|
||||
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
|
||||
// this won't do much, but it should at least check some structural
|
||||
// properties.
|
||||
if (failed(theModule->verify())) {
|
||||
if (failed(theModule.verify())) {
|
||||
emitError(mlir::UnknownLoc::get(&context), "Module verification error");
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -96,7 +96,7 @@ private:
|
|||
mlir::MLIRContext &context;
|
||||
|
||||
/// A "module" matches a source file: it contains a list of functions.
|
||||
std::unique_ptr<mlir::Module> theModule;
|
||||
mlir::Module theModule;
|
||||
|
||||
/// The builder is a helper class to create IR inside a function. It is
|
||||
/// re-initialized every time we enter a function and kept around as a
|
||||
|
@ -500,8 +500,8 @@ private:
|
|||
namespace toy {
|
||||
|
||||
// The public API for codegen.
|
||||
std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
|
||||
ModuleAST &moduleAST) {
|
||||
mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
|
||||
ModuleAST &moduleAST) {
|
||||
return MLIRGenImpl(context).mlirGen(moduleAST);
|
||||
}
|
||||
|
||||
|
|
|
@ -75,7 +75,7 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
|||
|
||||
int dumpMLIR() {
|
||||
mlir::MLIRContext context;
|
||||
std::unique_ptr<mlir::Module> module;
|
||||
mlir::OwningModuleRef module;
|
||||
if (inputType == InputType::MLIR ||
|
||||
llvm::StringRef(inputFilename).endswith(".mlir")) {
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||
|
@ -86,7 +86,7 @@ int dumpMLIR() {
|
|||
}
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
||||
module.reset(mlir::parseSourceFile(sourceMgr, &context));
|
||||
module = mlir::parseSourceFile(sourceMgr, &context);
|
||||
if (!module) {
|
||||
llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
||||
return 3;
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
class Module;
|
||||
class OwningModuleRef;
|
||||
} // namespace mlir
|
||||
|
||||
namespace toy {
|
||||
|
@ -35,8 +35,7 @@ class ModuleAST;
|
|||
|
||||
/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
|
||||
/// or nullptr on failure.
|
||||
std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
|
||||
ModuleAST &moduleAST);
|
||||
mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
|
||||
} // namespace toy
|
||||
|
||||
#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
|
||||
|
|
|
@ -67,10 +67,10 @@ public:
|
|||
|
||||
/// Public API: convert the AST for a Toy module (source file) to an MLIR
|
||||
/// Module.
|
||||
std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
|
||||
mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
|
||||
// We create an empty MLIR module and codegen functions one at a time and
|
||||
// add them to the module.
|
||||
theModule = make_unique<mlir::Module>(&context);
|
||||
theModule = mlir::Module::create(&context);
|
||||
|
||||
for (FunctionAST &F : moduleAST) {
|
||||
auto func = mlirGen(F);
|
||||
|
@ -97,7 +97,7 @@ private:
|
|||
mlir::MLIRContext &context;
|
||||
|
||||
/// A "module" matches a source file: it contains a list of functions.
|
||||
std::unique_ptr<mlir::Module> theModule;
|
||||
mlir::OwningModuleRef theModule;
|
||||
|
||||
/// The builder is a helper class to create IR inside a function. It is
|
||||
/// re-initialized every time we enter a function and kept around as a
|
||||
|
@ -469,8 +469,8 @@ private:
|
|||
namespace toy {
|
||||
|
||||
// The public API for codegen.
|
||||
std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
|
||||
ModuleAST &moduleAST) {
|
||||
mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
|
||||
ModuleAST &moduleAST) {
|
||||
return MLIRGenImpl(context).mlirGen(moduleAST);
|
||||
}
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ int dumpMLIR() {
|
|||
mlir::registerDialect<ToyDialect>();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
std::unique_ptr<mlir::Module> module;
|
||||
mlir::OwningModuleRef module;
|
||||
if (inputType == InputType::MLIR ||
|
||||
llvm::StringRef(inputFilename).endswith(".mlir")) {
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||
|
@ -90,7 +90,7 @@ int dumpMLIR() {
|
|||
}
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
||||
module.reset(mlir::parseSourceFile(sourceMgr, &context));
|
||||
module = mlir::parseSourceFile(sourceMgr, &context);
|
||||
if (!module) {
|
||||
llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
||||
return 3;
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
class Module;
|
||||
class OwningModuleRef;
|
||||
} // namespace mlir
|
||||
|
||||
namespace toy {
|
||||
|
@ -35,8 +35,7 @@ class ModuleAST;
|
|||
|
||||
/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
|
||||
/// or nullptr on failure.
|
||||
std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
|
||||
ModuleAST &moduleAST);
|
||||
mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
|
||||
} // namespace toy
|
||||
|
||||
#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
|
||||
|
|
|
@ -67,10 +67,10 @@ public:
|
|||
|
||||
/// Public API: convert the AST for a Toy module (source file) to an MLIR
|
||||
/// Module.
|
||||
std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
|
||||
mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
|
||||
// We create an empty MLIR module and codegen functions one at a time and
|
||||
// add them to the module.
|
||||
theModule = make_unique<mlir::Module>(&context);
|
||||
theModule = mlir::Module::create(&context);
|
||||
|
||||
for (FunctionAST &F : moduleAST) {
|
||||
auto func = mlirGen(F);
|
||||
|
@ -97,7 +97,7 @@ private:
|
|||
mlir::MLIRContext &context;
|
||||
|
||||
/// A "module" matches a source file: it contains a list of functions.
|
||||
std::unique_ptr<mlir::Module> theModule;
|
||||
mlir::OwningModuleRef theModule;
|
||||
|
||||
/// The builder is a helper class to create IR inside a function. It is
|
||||
/// re-initialized every time we enter a function and kept around as a
|
||||
|
@ -469,8 +469,8 @@ private:
|
|||
namespace toy {
|
||||
|
||||
// The public API for codegen.
|
||||
std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
|
||||
ModuleAST &moduleAST) {
|
||||
mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
|
||||
ModuleAST &moduleAST) {
|
||||
return MLIRGenImpl(context).mlirGen(moduleAST);
|
||||
}
|
||||
|
||||
|
|
|
@ -119,7 +119,7 @@ public:
|
|||
};
|
||||
|
||||
void runOnModule() override {
|
||||
auto &module = getModule();
|
||||
auto module = getModule();
|
||||
auto main = module.getNamedFunction("main");
|
||||
if (!main) {
|
||||
emitError(mlir::UnknownLoc::get(module.getContext()),
|
||||
|
|
|
@ -78,7 +78,7 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
|||
return parser.ParseModule();
|
||||
}
|
||||
|
||||
mlir::LogicalResult optimize(mlir::Module &module) {
|
||||
mlir::LogicalResult optimize(mlir::Module module) {
|
||||
mlir::PassManager pm;
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(createShapeInferencePass());
|
||||
|
@ -86,7 +86,7 @@ mlir::LogicalResult optimize(mlir::Module &module) {
|
|||
// Apply any generic pass manager command line options.
|
||||
applyPassManagerCLOptions(pm);
|
||||
|
||||
return pm.run(&module);
|
||||
return pm.run(module);
|
||||
}
|
||||
|
||||
int dumpMLIR() {
|
||||
|
@ -97,7 +97,7 @@ int dumpMLIR() {
|
|||
mlir::registerPassManagerCLOptions();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
std::unique_ptr<mlir::Module> module;
|
||||
mlir::OwningModuleRef module;
|
||||
if (inputType == InputType::MLIR ||
|
||||
llvm::StringRef(inputFilename).endswith(".mlir")) {
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||
|
@ -108,7 +108,7 @@ int dumpMLIR() {
|
|||
}
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
||||
module.reset(mlir::parseSourceFile(sourceMgr, &context));
|
||||
module = mlir::parseSourceFile(sourceMgr, &context);
|
||||
if (!module) {
|
||||
llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
||||
return 3;
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
class Module;
|
||||
class OwningModuleRef;
|
||||
} // namespace mlir
|
||||
|
||||
namespace toy {
|
||||
|
@ -35,8 +35,7 @@ class ModuleAST;
|
|||
|
||||
/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
|
||||
/// or nullptr on failure.
|
||||
std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
|
||||
ModuleAST &moduleAST);
|
||||
mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
|
||||
} // namespace toy
|
||||
|
||||
#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
|
||||
|
|
|
@ -136,7 +136,7 @@ public:
|
|||
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Get or create the declaration of the printf function in the module.
|
||||
Function printfFunc = getPrintf(*op->getFunction().getModule());
|
||||
Function printfFunc = getPrintf(op->getFunction().getModule());
|
||||
|
||||
auto print = cast<toy::PrintOp>(op);
|
||||
auto loc = print.getLoc();
|
||||
|
@ -205,13 +205,13 @@ private:
|
|||
|
||||
/// Return the prototype declaration for printf in the module, create it if
|
||||
/// necessary.
|
||||
Function getPrintf(Module &module) const {
|
||||
Function getPrintf(Module module) const {
|
||||
auto printfFunc = module.getNamedFunction("printf");
|
||||
if (printfFunc)
|
||||
return printfFunc;
|
||||
|
||||
// Create a function declaration for printf, signature is `i32 (i8*, ...)`
|
||||
Builder builder(&module);
|
||||
Builder builder(module);
|
||||
auto *dialect =
|
||||
module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
|
||||
|
|
|
@ -67,10 +67,10 @@ public:
|
|||
|
||||
/// Public API: convert the AST for a Toy module (source file) to an MLIR
|
||||
/// Module.
|
||||
std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
|
||||
mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
|
||||
// We create an empty MLIR module and codegen functions one at a time and
|
||||
// add them to the module.
|
||||
theModule = make_unique<mlir::Module>(&context);
|
||||
theModule = mlir::Module::create(&context);
|
||||
|
||||
for (FunctionAST &F : moduleAST) {
|
||||
auto func = mlirGen(F);
|
||||
|
@ -97,7 +97,7 @@ private:
|
|||
mlir::MLIRContext &context;
|
||||
|
||||
/// A "module" matches a source file: it contains a list of functions.
|
||||
std::unique_ptr<mlir::Module> theModule;
|
||||
mlir::OwningModuleRef theModule;
|
||||
|
||||
/// The builder is a helper class to create IR inside a function. It is
|
||||
/// re-initialized every time we enter a function and kept around as a
|
||||
|
@ -469,8 +469,8 @@ private:
|
|||
namespace toy {
|
||||
|
||||
// The public API for codegen.
|
||||
std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
|
||||
ModuleAST &moduleAST) {
|
||||
mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
|
||||
ModuleAST &moduleAST) {
|
||||
return MLIRGenImpl(context).mlirGen(moduleAST);
|
||||
}
|
||||
|
||||
|
|
|
@ -119,8 +119,8 @@ public:
|
|||
};
|
||||
|
||||
void runOnModule() override {
|
||||
auto &module = getModule();
|
||||
mlir::ModuleManager moduleManager(&module);
|
||||
auto module = getModule();
|
||||
mlir::ModuleManager moduleManager(module);
|
||||
auto main = moduleManager.getNamedFunction("main");
|
||||
if (!main) {
|
||||
emitError(mlir::UnknownLoc::get(module.getContext()),
|
||||
|
|
|
@ -101,7 +101,7 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
|||
return parser.ParseModule();
|
||||
}
|
||||
|
||||
mlir::LogicalResult optimize(mlir::Module &module) {
|
||||
mlir::LogicalResult optimize(mlir::Module module) {
|
||||
mlir::PassManager pm;
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(createShapeInferencePass());
|
||||
|
@ -111,10 +111,10 @@ mlir::LogicalResult optimize(mlir::Module &module) {
|
|||
// Apply any generic pass manager command line options.
|
||||
applyPassManagerCLOptions(pm);
|
||||
|
||||
return pm.run(&module);
|
||||
return pm.run(module);
|
||||
}
|
||||
|
||||
mlir::LogicalResult lowerDialect(mlir::Module &module, bool OnlyLinalg) {
|
||||
mlir::LogicalResult lowerDialect(mlir::Module module, bool OnlyLinalg) {
|
||||
mlir::PassManager pm;
|
||||
pm.addPass(createEarlyLoweringPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
|
@ -127,14 +127,14 @@ mlir::LogicalResult lowerDialect(mlir::Module &module, bool OnlyLinalg) {
|
|||
// Apply any generic pass manager command line options.
|
||||
applyPassManagerCLOptions(pm);
|
||||
|
||||
return pm.run(&module);
|
||||
return pm.run(module);
|
||||
}
|
||||
|
||||
std::unique_ptr<mlir::Module> loadFileAndProcessModule(
|
||||
mlir::OwningModuleRef loadFileAndProcessModule(
|
||||
mlir::MLIRContext &context, bool EnableLinalgLowering = false,
|
||||
bool EnableLLVMLowering = false, bool EnableOpt = false) {
|
||||
|
||||
std::unique_ptr<mlir::Module> module;
|
||||
mlir::OwningModuleRef module;
|
||||
if (inputType == InputType::MLIR ||
|
||||
llvm::StringRef(inputFilename).endswith(".mlir")) {
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||
|
@ -145,7 +145,7 @@ std::unique_ptr<mlir::Module> loadFileAndProcessModule(
|
|||
}
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
||||
module.reset(mlir::parseSourceFile(sourceMgr, &context));
|
||||
module = mlir::parseSourceFile(sourceMgr, &context);
|
||||
if (!module) {
|
||||
llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
||||
return nullptr;
|
||||
|
@ -252,7 +252,7 @@ int runJit() {
|
|||
// the module.
|
||||
auto optPipeline = mlir::makeOptimizingTransformer(
|
||||
/* optLevel=*/EnableOpt ? 3 : 0, /* sizeLevel=*/0);
|
||||
auto maybeEngine = mlir::ExecutionEngine::create(module.get(), optPipeline);
|
||||
auto maybeEngine = mlir::ExecutionEngine::create(*module, optPipeline);
|
||||
assert(maybeEngine && "failed to construct an execution engine");
|
||||
auto &engine = maybeEngine.get();
|
||||
|
||||
|
|
|
@ -54,10 +54,10 @@ namespace {
|
|||
struct MyFunctionPass : public FunctionPass<MyFunctionPass> {
|
||||
void runOnFunction() override {
|
||||
// Get the current function being operated on.
|
||||
Function *f = getFunction();
|
||||
Function f = getFunction();
|
||||
|
||||
// Operate on the operations within the function.
|
||||
f->walk([](Operation *inst) {
|
||||
f.walk([](Operation *inst) {
|
||||
....
|
||||
});
|
||||
}
|
||||
|
@ -94,10 +94,10 @@ namespace {
|
|||
struct MyModulePass : public ModulePass<MyModulePass> {
|
||||
void runOnModule() override {
|
||||
// Get the current module being operated on.
|
||||
Module *m = getModule();
|
||||
Module m = getModule();
|
||||
|
||||
// Operate on the functions within the module.
|
||||
for (auto &func : *m) {
|
||||
for (auto func : m) {
|
||||
....
|
||||
}
|
||||
}
|
||||
|
@ -149,7 +149,7 @@ struct MyFunctionAnalysis {
|
|||
/// An interesting module analysis.
|
||||
struct MyModuleAnalysis {
|
||||
// Compute this analysis with the provided module.
|
||||
MyModuleAnalysis(Module *module);
|
||||
MyModuleAnalysis(Module module);
|
||||
};
|
||||
|
||||
void MyFunctionPass::runOnFunction() {
|
||||
|
@ -181,7 +181,7 @@ void MyModulePass::runOnModule() {
|
|||
|
||||
// Query MyFunctionAnalysis for a child function of the current module. It
|
||||
// will be computed if it doesn't exist.
|
||||
auto *fn = &*getModule().begin();
|
||||
auto fn = *getModule().begin();
|
||||
MyFunctionAnalysis &myAnalysis = getFunctionAnalysis<MyFunctionAnalysis>(fn);
|
||||
}
|
||||
```
|
||||
|
@ -255,7 +255,7 @@ pm.addPass(new MyFunctionPass3());
|
|||
pm.addPass(new MyModulePass2());
|
||||
|
||||
// Run the pass manager on a module.
|
||||
Module *m = ...;
|
||||
Module m = ...;
|
||||
if (failed(pm.run(m)))
|
||||
... // One of the passes signaled a failure.
|
||||
```
|
||||
|
@ -384,7 +384,7 @@ unsigned domInfoCount;
|
|||
pm.addInstrumentation(new DominanceCounterInstrumentation(domInfoCount));
|
||||
|
||||
// Run the pass manager on a module.
|
||||
Module *m = ...;
|
||||
Module m = ...;
|
||||
if (failed(pm.run(m)))
|
||||
...
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ namespace LLVM {
|
|||
/// support different values coming from the same predecessor. If a block has
|
||||
/// another block as a successor more than once with different values, insert
|
||||
/// a new dummy block for LLVM PHI nodes to tell the sources apart.
|
||||
void ensureDistinctSuccessors(Module *m);
|
||||
void ensureDistinctSuccessors(Module m);
|
||||
} // namespace LLVM
|
||||
|
||||
} // namespace mlir
|
||||
|
|
|
@ -62,7 +62,7 @@ public:
|
|||
/// If `sharedLibPaths` are provided, the underlying JIT-compilation will open
|
||||
/// and link the shared libraries for symbol resolution.
|
||||
static llvm::Expected<std::unique_ptr<ExecutionEngine>>
|
||||
create(Module *m, std::function<llvm::Error(llvm::Module *)> transformer = {},
|
||||
create(Module m, std::function<llvm::Error(llvm::Module *)> transformer = {},
|
||||
ArrayRef<StringRef> sharedLibPaths = {});
|
||||
|
||||
/// Looks up a packed-argument function with the given name and returns a
|
||||
|
|
|
@ -57,12 +57,12 @@ class UnitAttr;
|
|||
class Builder {
|
||||
public:
|
||||
explicit Builder(MLIRContext *context) : context(context) {}
|
||||
explicit Builder(Module *module);
|
||||
explicit Builder(Module module);
|
||||
|
||||
MLIRContext *getContext() const { return context; }
|
||||
|
||||
Identifier getIdentifier(StringRef str);
|
||||
Module *createModule();
|
||||
Module createModule();
|
||||
|
||||
// Locations.
|
||||
Location getUnknownLoc();
|
||||
|
|
|
@ -34,10 +34,12 @@ class MLIRContext;
|
|||
class Module;
|
||||
|
||||
namespace detail {
|
||||
class ModuleStorage;
|
||||
|
||||
/// This class represents all of the internal state of a Function. This allows
|
||||
/// for the Function class to be value typed.
|
||||
class FunctionStorage
|
||||
: public llvm::ilist_node_with_parent<FunctionStorage, Module> {
|
||||
: public llvm::ilist_node_with_parent<FunctionStorage, ModuleStorage> {
|
||||
FunctionStorage(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
FunctionStorage(Location location, StringRef name, FunctionType type,
|
||||
|
@ -47,7 +49,7 @@ class FunctionStorage
|
|||
Identifier name;
|
||||
|
||||
/// The module this function is embedded into.
|
||||
Module *module = nullptr;
|
||||
ModuleStorage *module = nullptr;
|
||||
|
||||
/// The source location the function was defined or derived from.
|
||||
Location location;
|
||||
|
@ -116,7 +118,7 @@ public:
|
|||
}
|
||||
|
||||
MLIRContext *getContext();
|
||||
Module *getModule() { return impl->module; }
|
||||
Module getModule();
|
||||
|
||||
/// Add an entry block to an empty function, and set up the block arguments
|
||||
/// to match the signature of the function.
|
||||
|
@ -541,7 +543,7 @@ struct ilist_traits<::mlir::detail::FunctionStorage>
|
|||
function_iterator first, function_iterator last);
|
||||
|
||||
private:
|
||||
mlir::Module *getContainingModule();
|
||||
mlir::detail::ModuleStorage *getContainingModule();
|
||||
};
|
||||
|
||||
// Functions hash just like pointers.
|
||||
|
|
|
@ -27,12 +27,46 @@
|
|||
#include "llvm/ADT/ilist.h"
|
||||
|
||||
namespace mlir {
|
||||
class Module;
|
||||
|
||||
namespace detail {
|
||||
class ModuleStorage {
|
||||
explicit ModuleStorage(MLIRContext *context) : context(context) {}
|
||||
|
||||
/// getSublistAccess() - Returns pointer to member of function list
|
||||
static llvm::iplist<FunctionStorage> ModuleStorage::*
|
||||
getSublistAccess(FunctionStorage *) {
|
||||
return &ModuleStorage::functions;
|
||||
}
|
||||
|
||||
/// The context attached to this module.
|
||||
MLIRContext *context;
|
||||
|
||||
/// This is the actual list of functions the module contains.
|
||||
llvm::iplist<FunctionStorage> functions;
|
||||
|
||||
friend Module;
|
||||
friend struct llvm::ilist_traits<FunctionStorage>;
|
||||
friend FunctionStorage;
|
||||
friend Function;
|
||||
};
|
||||
} // end namespace detail
|
||||
|
||||
class Module {
|
||||
public:
|
||||
explicit Module(MLIRContext *context) : context(context) {}
|
||||
Module(detail::ModuleStorage *impl = nullptr) : impl(impl) {}
|
||||
|
||||
MLIRContext *getContext() { return context; }
|
||||
/// Construct a new module object with the given context.
|
||||
static Module create(MLIRContext *context) {
|
||||
return new detail::ModuleStorage(context);
|
||||
}
|
||||
|
||||
MLIRContext *getContext() { return impl->context; }
|
||||
|
||||
/// Allow converting a Module to bool for null checks.
|
||||
operator bool() const { return impl; }
|
||||
bool operator==(Module other) const { return impl == other.impl; }
|
||||
bool operator!=(Module other) const { return !(*this == other); }
|
||||
|
||||
/// An iterator class used to iterate over the held functions.
|
||||
class iterator : public llvm::mapped_iterator<
|
||||
|
@ -56,14 +90,14 @@ public:
|
|||
llvm::iterator_range<iterator> getFunctions() { return {begin(), end()}; }
|
||||
|
||||
// Iteration over the functions in the module.
|
||||
iterator begin() { return functions.begin(); }
|
||||
iterator end() { return functions.end(); }
|
||||
Function front() { return &functions.front(); }
|
||||
Function back() { return &functions.back(); }
|
||||
iterator begin() { return impl->functions.begin(); }
|
||||
iterator end() { return impl->functions.end(); }
|
||||
Function front() { return &impl->functions.front(); }
|
||||
Function back() { return &impl->functions.back(); }
|
||||
|
||||
void push_back(Function fn) { functions.push_back(fn.impl); }
|
||||
void push_back(Function fn) { impl->functions.push_back(fn.impl); }
|
||||
void insert(iterator insertPt, Function fn) {
|
||||
functions.insert(insertPt.getCurrent(), fn.impl);
|
||||
impl->functions.insert(insertPt.getCurrent(), fn.impl);
|
||||
}
|
||||
|
||||
// Interfaces for working with the symbol table.
|
||||
|
@ -79,6 +113,7 @@ public:
|
|||
/// name exists. Function names never include the @ on them. Note: This
|
||||
/// performs a linear scan of held symbols.
|
||||
Function getNamedFunction(Identifier name) {
|
||||
auto &functions = impl->functions;
|
||||
auto it = llvm::find_if(functions, [name](detail::FunctionStorage &fn) {
|
||||
return Function(&fn).getName() == name;
|
||||
});
|
||||
|
@ -93,22 +128,27 @@ public:
|
|||
void print(raw_ostream &os);
|
||||
void dump();
|
||||
|
||||
/// Erase the current module.
|
||||
void erase() {
|
||||
assert(impl && "expected valid module");
|
||||
delete impl;
|
||||
}
|
||||
|
||||
/// Methods for supporting PointerLikeTypeTraits.
|
||||
const void *getAsOpaquePointer() const {
|
||||
return static_cast<const void *>(impl);
|
||||
}
|
||||
static Module getFromOpaquePointer(const void *pointer) {
|
||||
return reinterpret_cast<detail::ModuleStorage *>(
|
||||
const_cast<void *>(pointer));
|
||||
}
|
||||
|
||||
private:
|
||||
friend struct llvm::ilist_traits<detail::FunctionStorage>;
|
||||
friend detail::FunctionStorage;
|
||||
friend Function;
|
||||
|
||||
/// getSublistAccess() - Returns pointer to member of function list
|
||||
static llvm::iplist<detail::FunctionStorage> Module::*
|
||||
getSublistAccess(detail::FunctionStorage *) {
|
||||
return &Module::functions;
|
||||
}
|
||||
|
||||
/// The context attached to this module.
|
||||
MLIRContext *context;
|
||||
|
||||
/// This is the actual list of functions the module contains.
|
||||
llvm::iplist<detail::FunctionStorage> functions;
|
||||
/// The internal impl storage object.
|
||||
detail::ModuleStorage *impl = nullptr;
|
||||
};
|
||||
|
||||
/// A class used to manage the symbols held by a module. This class handles
|
||||
|
@ -116,7 +156,7 @@ private:
|
|||
/// efficent named lookup to held symbols.
|
||||
class ModuleManager {
|
||||
public:
|
||||
ModuleManager(Module *module) : module(module), symbolTable(module) {}
|
||||
ModuleManager(Module module) : module(module), symbolTable(module) {}
|
||||
|
||||
/// Look up a symbol with the specified name, returning null if no such
|
||||
/// name exists. Names must never include the @ on them.
|
||||
|
@ -127,11 +167,11 @@ public:
|
|||
/// Insert a new symbol into the module, auto-renaming it as necessary.
|
||||
void insert(Function function) {
|
||||
symbolTable.insert(function);
|
||||
module->push_back(function);
|
||||
module.push_back(function);
|
||||
}
|
||||
void insert(Module::iterator insertPt, Function function) {
|
||||
symbolTable.insert(function);
|
||||
module->insert(insertPt, function);
|
||||
module.insert(insertPt, function);
|
||||
}
|
||||
|
||||
/// Remove the given symbol from the module symbol table and then erase it.
|
||||
|
@ -141,16 +181,53 @@ public:
|
|||
}
|
||||
|
||||
/// Return the internally held module.
|
||||
Module *getModule() const { return module; }
|
||||
Module getModule() const { return module; }
|
||||
|
||||
/// Return the context of the internal module.
|
||||
MLIRContext *getContext() const { return module->getContext(); }
|
||||
MLIRContext *getContext() const { return getModule().getContext(); }
|
||||
|
||||
private:
|
||||
Module *module;
|
||||
Module module;
|
||||
SymbolTable symbolTable;
|
||||
};
|
||||
|
||||
/// This class acts as an owning reference to a Module, and will automatically
|
||||
/// destory the held Module if valid.
|
||||
class OwningModuleRef {
|
||||
public:
|
||||
OwningModuleRef(std::nullptr_t = nullptr) {}
|
||||
OwningModuleRef(Module module) : module(module) {}
|
||||
OwningModuleRef(OwningModuleRef &&other) : module(other.release()) {}
|
||||
~OwningModuleRef() {
|
||||
if (module)
|
||||
module.erase();
|
||||
}
|
||||
|
||||
// Assign from another module reference.
|
||||
OwningModuleRef &operator=(OwningModuleRef &&other) {
|
||||
if (module)
|
||||
module.erase();
|
||||
module = other.release();
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Allow accessing the internal module.
|
||||
Module get() const { return module; }
|
||||
Module operator*() const { return module; }
|
||||
Module *operator->() { return &module; }
|
||||
explicit operator bool() const { return module; }
|
||||
|
||||
/// Release the referenced module.
|
||||
Module release() {
|
||||
Module released;
|
||||
std::swap(released, module);
|
||||
return released;
|
||||
}
|
||||
|
||||
private:
|
||||
Module module;
|
||||
};
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Module Operation.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -196,4 +273,20 @@ public:
|
|||
|
||||
} // end namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
|
||||
/// Allow stealing the low bits of ModuleStorage.
|
||||
template <> struct PointerLikeTypeTraits<mlir::Module> {
|
||||
public:
|
||||
static inline void *getAsVoidPointer(mlir::Module I) {
|
||||
return const_cast<void *>(I.getAsOpaquePointer());
|
||||
}
|
||||
static inline mlir::Module getFromVoidPointer(void *P) {
|
||||
return mlir::Module::getFromOpaquePointer(P);
|
||||
}
|
||||
enum { NumLowBitsAvailable = 3 };
|
||||
};
|
||||
|
||||
} // end namespace llvm
|
||||
|
||||
#endif // MLIR_IR_MODULE_H
|
||||
|
|
|
@ -31,7 +31,7 @@ class MLIRContext;
|
|||
class SymbolTable {
|
||||
public:
|
||||
/// Build a symbol table with the symbols within the given module.
|
||||
SymbolTable(Module *module);
|
||||
SymbolTable(Module module);
|
||||
|
||||
/// Look up a symbol with the specified name, returning null if no such
|
||||
/// name exists. Names never include the @ on them.
|
||||
|
|
|
@ -37,24 +37,24 @@ class Type;
|
|||
/// This parses the file specified by the indicated SourceMgr and returns an
|
||||
/// MLIR module if it was valid. If not, the error message is emitted through
|
||||
/// the error handler registered in the context, and a null pointer is returned.
|
||||
Module *parseSourceFile(const llvm::SourceMgr &sourceMgr, MLIRContext *context);
|
||||
Module parseSourceFile(const llvm::SourceMgr &sourceMgr, MLIRContext *context);
|
||||
|
||||
/// This parses the file specified by the indicated filename and returns an
|
||||
/// MLIR module if it was valid. If not, the error message is emitted through
|
||||
/// the error handler registered in the context, and a null pointer is returned.
|
||||
Module *parseSourceFile(llvm::StringRef filename, MLIRContext *context);
|
||||
Module parseSourceFile(llvm::StringRef filename, MLIRContext *context);
|
||||
|
||||
/// This parses the file specified by the indicated filename using the provided
|
||||
/// SourceMgr and returns an MLIR module if it was valid. If not, the error
|
||||
/// message is emitted through the error handler registered in the context, and
|
||||
/// a null pointer is returned.
|
||||
Module *parseSourceFile(llvm::StringRef filename, llvm::SourceMgr &sourceMgr,
|
||||
MLIRContext *context);
|
||||
Module parseSourceFile(llvm::StringRef filename, llvm::SourceMgr &sourceMgr,
|
||||
MLIRContext *context);
|
||||
|
||||
/// This parses the module string to a MLIR module if it was valid. If not, the
|
||||
/// error message is emitted through the error handler registered in the
|
||||
/// context, and a null pointer is returned.
|
||||
Module *parseSourceString(llvm::StringRef moduleStr, MLIRContext *context);
|
||||
Module parseSourceString(llvm::StringRef moduleStr, MLIRContext *context);
|
||||
|
||||
/// This parses a single MLIR type to an MLIR context if it was valid. If not,
|
||||
/// an error message is emitted through a new SourceMgrDiagnosticHandler
|
||||
|
|
|
@ -223,7 +223,7 @@ private:
|
|||
/// An analysis manager for a specific module instance.
|
||||
class ModuleAnalysisManager {
|
||||
public:
|
||||
ModuleAnalysisManager(Module *module, PassInstrumentor *passInstrumentor)
|
||||
ModuleAnalysisManager(Module module, PassInstrumentor *passInstrumentor)
|
||||
: moduleAnalyses(module), passInstrumentor(passInstrumentor) {}
|
||||
ModuleAnalysisManager(const ModuleAnalysisManager &) = delete;
|
||||
ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete;
|
||||
|
@ -273,7 +273,7 @@ private:
|
|||
functionAnalyses;
|
||||
|
||||
/// The analyses for the owning module.
|
||||
detail::AnalysisMap<Module *> moduleAnalyses;
|
||||
detail::AnalysisMap<Module> moduleAnalyses;
|
||||
|
||||
/// An optional instrumentation object.
|
||||
PassInstrumentor *passInstrumentor;
|
||||
|
|
|
@ -138,8 +138,7 @@ private:
|
|||
/// Pass to transform a module. Derived passes should not inherit from this
|
||||
/// class directly, and instead should use the CRTP ModulePass class.
|
||||
class ModulePassBase : public Pass {
|
||||
using PassStateT =
|
||||
detail::PassExecutionState<Module *, ModuleAnalysisManager>;
|
||||
using PassStateT = detail::PassExecutionState<Module, ModuleAnalysisManager>;
|
||||
|
||||
public:
|
||||
static bool classof(const Pass *pass) {
|
||||
|
@ -153,7 +152,7 @@ protected:
|
|||
virtual void runOnModule() = 0;
|
||||
|
||||
/// Return the current module being transformed.
|
||||
Module &getModule() { return *getPassState().irAndPassFailed.getPointer(); }
|
||||
Module getModule() { return getPassState().irAndPassFailed.getPointer(); }
|
||||
|
||||
/// Return the MLIR context for the current module being transformed.
|
||||
MLIRContext &getContext() { return *getModule().getContext(); }
|
||||
|
@ -172,7 +171,7 @@ protected:
|
|||
private:
|
||||
/// Forwarding function to execute this pass.
|
||||
LLVM_NODISCARD
|
||||
LogicalResult run(Module *module, ModuleAnalysisManager &mam);
|
||||
LogicalResult run(Module module, ModuleAnalysisManager &mam);
|
||||
|
||||
/// The current execution state for the pass.
|
||||
llvm::Optional<PassStateT> passState;
|
||||
|
|
|
@ -60,7 +60,7 @@ public:
|
|||
|
||||
/// Run the passes within this manager on the provided module.
|
||||
LLVM_NODISCARD
|
||||
LogicalResult run(Module *module);
|
||||
LogicalResult run(Module module);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Pipeline Building
|
||||
|
|
|
@ -38,7 +38,7 @@ class Module;
|
|||
/// from the registered LLVM IR dialect. In case of error, report it
|
||||
/// to the error handler registered with the MLIR context, if any (obtained from
|
||||
/// the MLIR module), and return `nullptr`.
|
||||
std::unique_ptr<llvm::Module> translateModuleToLLVMIR(Module &m);
|
||||
std::unique_ptr<llvm::Module> translateModuleToLLVMIR(Module m);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
|
||||
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
|
@ -48,7 +48,7 @@ namespace LLVM {
|
|||
class ModuleTranslation {
|
||||
public:
|
||||
template <typename T = ModuleTranslation>
|
||||
static std::unique_ptr<llvm::Module> translateModule(Module &m) {
|
||||
static std::unique_ptr<llvm::Module> translateModule(Module m) {
|
||||
auto llvmModule = prepareLLVMModule(m);
|
||||
|
||||
T translator(m);
|
||||
|
@ -63,17 +63,17 @@ protected:
|
|||
// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
|
||||
// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an
|
||||
// LLVMContext, the LLVM IR module will be created in that context.
|
||||
explicit ModuleTranslation(Module &module) : mlirModule(module) {}
|
||||
explicit ModuleTranslation(Module module) : mlirModule(module) {}
|
||||
virtual ~ModuleTranslation() {}
|
||||
|
||||
virtual bool convertOperation(Operation &op, llvm::IRBuilder<> &builder);
|
||||
static std::unique_ptr<llvm::Module> prepareLLVMModule(Module &m);
|
||||
static std::unique_ptr<llvm::Module> prepareLLVMModule(Module m);
|
||||
|
||||
private:
|
||||
|
||||
bool convertFunctions();
|
||||
bool convertOneFunction(Function &func);
|
||||
void connectPHINodes(Function &func);
|
||||
bool convertOneFunction(Function func);
|
||||
void connectPHINodes(Function func);
|
||||
bool convertBlock(Block &bb, bool ignoreArguments);
|
||||
|
||||
template <typename Range>
|
||||
|
@ -83,7 +83,7 @@ private:
|
|||
Location loc);
|
||||
|
||||
// Original and translated module.
|
||||
Module &mlirModule;
|
||||
Module mlirModule;
|
||||
std::unique_ptr<llvm::Module> llvmModule;
|
||||
|
||||
protected:
|
||||
|
|
|
@ -30,7 +30,6 @@ class Module;
|
|||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class Module;
|
||||
|
||||
/// Convert the given MLIR module into NVVM IR. This conversion requires the
|
||||
|
@ -38,7 +37,7 @@ class Module;
|
|||
/// from the registered LLVM IR dialect. In case of error, report it
|
||||
/// to the error handler registered with the MLIR context, if any (obtained from
|
||||
/// the MLIR module), and return `nullptr`.
|
||||
std::unique_ptr<llvm::Module> translateModuleToNVVMIR(Module &m);
|
||||
std::unique_ptr<llvm::Module> translateModuleToNVVMIR(Module m);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -339,7 +339,7 @@ private:
|
|||
/// conversion object. This function returns failure if a type conversion
|
||||
/// failed.
|
||||
LLVM_NODISCARD LogicalResult applyConversionPatterns(
|
||||
Module &module, ConversionTarget &target, TypeConverter &converter,
|
||||
Module module, ConversionTarget &target, TypeConverter &converter,
|
||||
OwningRewritePatternList &&patterns);
|
||||
|
||||
/// Convert the given functions with the provided conversion patterns. This
|
||||
|
|
|
@ -27,17 +27,17 @@
|
|||
namespace mlir {
|
||||
class MLIRContext;
|
||||
class Module;
|
||||
class OwningModuleRef;
|
||||
|
||||
/// Interface of the function that translates a file to MLIR. The
|
||||
/// implementation should create a new MLIR Module in the given context and
|
||||
/// return a pointer to it, or a nullptr in case of any error.
|
||||
using TranslateToMLIRFunction =
|
||||
std::function<std::unique_ptr<Module>(llvm::StringRef, MLIRContext *)>;
|
||||
std::function<OwningModuleRef(llvm::StringRef, MLIRContext *)>;
|
||||
/// Interface of the function that translates MLIR to a different format and
|
||||
/// outputs the result to a file. The implementation should return "true" on
|
||||
/// error and "false" otherwise. It is allowed to modify the module.
|
||||
using TranslateFromMLIRFunction =
|
||||
std::function<bool(Module *, llvm::StringRef)>;
|
||||
using TranslateFromMLIRFunction = std::function<bool(Module, llvm::StringRef)>;
|
||||
|
||||
/// Use Translate[To|From]MLIRRegistration as a global initialiser that
|
||||
/// registers a function and associates it with name. This requires that a
|
||||
|
|
|
@ -139,7 +139,7 @@ LogicalResult
|
|||
GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) {
|
||||
Builder builder(function.getContext());
|
||||
|
||||
std::unique_ptr<Module> module(builder.createModule());
|
||||
OwningModuleRef module = builder.createModule();
|
||||
|
||||
// TODO(herhut): Also handle called functions.
|
||||
module->push_back(function.clone());
|
||||
|
@ -147,8 +147,9 @@ GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) {
|
|||
auto llvmModule = translateModuleToNVVMIR(*module);
|
||||
auto cubin = convertModuleToCubin(*llvmModule, function);
|
||||
|
||||
if (!cubin)
|
||||
if (!cubin) {
|
||||
return function.emitError("Translation to CUDA binary failed.");
|
||||
}
|
||||
|
||||
function.setAttr(kCubinAnnotation,
|
||||
builder.getStringAttr({cubin->data(), cubin->size()}));
|
||||
|
|
|
@ -152,8 +152,8 @@ private:
|
|||
// The types in comments give the actual types expected/returned but the API
|
||||
// uses void pointers. This is fine as they have the same linkage in C.
|
||||
void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
|
||||
Module &module = getModule();
|
||||
Builder builder(&module);
|
||||
Module module = getModule();
|
||||
Builder builder(module);
|
||||
if (!module.getNamedFunction(cuModuleLoadName)) {
|
||||
module.push_back(
|
||||
Function::create(loc, cuModuleLoadName,
|
||||
|
@ -343,7 +343,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
ArrayRef<Value *>{cuModule, data.getResult(0)});
|
||||
// Get the function from the module. The name corresponds to the name of
|
||||
// the kernel function.
|
||||
auto cuModuleRef =
|
||||
auto cuOwningModuleRef =
|
||||
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule);
|
||||
auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder);
|
||||
auto cuFunction = allocatePointer(builder, loc);
|
||||
|
@ -352,7 +352,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getFunctionAttr(cuModuleGetFunction),
|
||||
ArrayRef<Value *>{cuFunction, cuModuleRef, kernelName});
|
||||
ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName});
|
||||
// Grab the global stream needed for execution.
|
||||
Function cuGetStreamHelper =
|
||||
getModule().getNamedFunction(cuGetStreamHelperName);
|
||||
|
|
|
@ -115,7 +115,7 @@ public:
|
|||
void runOnModule() override {
|
||||
llvmDialect =
|
||||
getModule().getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
auto &module = getModule();
|
||||
auto module = getModule();
|
||||
Builder builder(&getContext());
|
||||
|
||||
auto functions = module.getFunctions();
|
||||
|
|
|
@ -442,13 +442,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
Function mallocFunc =
|
||||
op->getFunction().getModule()->getNamedFunction("malloc");
|
||||
op->getFunction().getModule().getNamedFunction("malloc");
|
||||
if (!mallocFunc) {
|
||||
auto mallocType =
|
||||
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
|
||||
mallocFunc =
|
||||
Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
|
||||
op->getFunction().getModule()->push_back(mallocFunc);
|
||||
op->getFunction().getModule().push_back(mallocFunc);
|
||||
}
|
||||
|
||||
// Allocate the underlying buffer and store a pointer to it in the MemRef
|
||||
|
@ -503,11 +503,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
|||
OperandAdaptor<DeallocOp> transformed(operands);
|
||||
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
Function freeFunc = op->getFunction().getModule()->getNamedFunction("free");
|
||||
Function freeFunc = op->getFunction().getModule().getNamedFunction("free");
|
||||
if (!freeFunc) {
|
||||
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
|
||||
freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
|
||||
op->getFunction().getModule()->push_back(freeFunc);
|
||||
op->getFunction().getModule().push_back(freeFunc);
|
||||
}
|
||||
|
||||
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
|
||||
|
@ -936,8 +936,8 @@ static void ensureDistinctSuccessors(Block &bb) {
|
|||
}
|
||||
}
|
||||
|
||||
void mlir::LLVM::ensureDistinctSuccessors(Module *m) {
|
||||
for (auto f : *m) {
|
||||
void mlir::LLVM::ensureDistinctSuccessors(Module m) {
|
||||
for (auto f : m) {
|
||||
for (auto &bb : f.getBlocks()) {
|
||||
::ensureDistinctSuccessors(bb);
|
||||
}
|
||||
|
@ -1010,8 +1010,8 @@ namespace {
|
|||
struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
|
||||
// Run the dialect converter on the module.
|
||||
void runOnModule() override {
|
||||
Module &m = getModule();
|
||||
LLVM::ensureDistinctSuccessors(&m);
|
||||
Module m = getModule();
|
||||
LLVM::ensureDistinctSuccessors(m);
|
||||
|
||||
LLVMTypeConverter converter(&getContext());
|
||||
OwningRewritePatternList patterns;
|
||||
|
|
|
@ -322,7 +322,7 @@ void packFunctionArguments(llvm::Module *module) {
|
|||
ExecutionEngine::~ExecutionEngine() = default;
|
||||
|
||||
Expected<std::unique_ptr<ExecutionEngine>>
|
||||
ExecutionEngine::create(Module *m,
|
||||
ExecutionEngine::create(Module m,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer,
|
||||
ArrayRef<StringRef> sharedLibPaths) {
|
||||
auto engine = llvm::make_unique<ExecutionEngine>();
|
||||
|
@ -330,7 +330,7 @@ ExecutionEngine::create(Module *m,
|
|||
if (!expectedJIT)
|
||||
return expectedJIT.takeError();
|
||||
|
||||
auto llvmModule = translateModuleToLLVMIR(*m);
|
||||
auto llvmModule = translateModuleToLLVMIR(m);
|
||||
if (!llvmModule)
|
||||
return make_string_error("could not convert to LLVM IR");
|
||||
// FIXME: the triple should be passed to the translation or dialect conversion
|
||||
|
|
|
@ -426,8 +426,8 @@ LogicalResult LaunchFuncOp::verify() {
|
|||
return emitOpError("attribute 'kernel' must be a function");
|
||||
}
|
||||
|
||||
auto *module = getOperation()->getFunction().getModule();
|
||||
Function kernelFunc = module->getNamedFunction(kernel());
|
||||
auto module = getOperation()->getFunction().getModule();
|
||||
Function kernelFunc = module.getNamedFunction(kernel());
|
||||
if (!kernelFunc)
|
||||
return emitError() << "kernel function '" << kernelAttr << "' is undefined";
|
||||
|
||||
|
|
|
@ -97,7 +97,7 @@ namespace {
|
|||
class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
|
||||
public:
|
||||
void runOnModule() override {
|
||||
ModuleManager moduleManager(&getModule());
|
||||
ModuleManager moduleManager(getModule());
|
||||
for (auto func : getModule()) {
|
||||
func.walk<mlir::gpu::LaunchOp>([&](mlir::gpu::LaunchOp op) {
|
||||
Function outlinedFunc = outlineKernelFunc(op);
|
||||
|
|
|
@ -91,7 +91,7 @@ public:
|
|||
explicit ModuleState(MLIRContext *context) : context(context) {}
|
||||
|
||||
// Initializes module state, populating affine map state.
|
||||
void initialize(Module *module);
|
||||
void initialize(Module module);
|
||||
|
||||
Twine getAttributeAlias(Attribute attr) const {
|
||||
auto alias = attrToAlias.find(attr);
|
||||
|
@ -301,12 +301,12 @@ void ModuleState::initializeSymbolAliases() {
|
|||
}
|
||||
|
||||
// Initializes module state, populating affine map and integer set state.
|
||||
void ModuleState::initialize(Module *module) {
|
||||
void ModuleState::initialize(Module module) {
|
||||
// Initialize the symbol aliases.
|
||||
initializeSymbolAliases();
|
||||
|
||||
// Walk the module and visit each operation.
|
||||
for (auto fn : *module) {
|
||||
for (auto fn : module) {
|
||||
visitType(fn.getType());
|
||||
for (auto attr : fn.getAttrs())
|
||||
ModuleState::visitAttribute(attr.second);
|
||||
|
@ -331,7 +331,7 @@ public:
|
|||
interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
|
||||
}
|
||||
|
||||
void print(Module *module);
|
||||
void print(Module module);
|
||||
|
||||
/// Print the given attribute. If 'mayElideType' is true, some attributes are
|
||||
/// printed without the type when the type matches the default used in the
|
||||
|
@ -451,13 +451,13 @@ void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
|
|||
}
|
||||
}
|
||||
|
||||
void ModulePrinter::print(Module *module) {
|
||||
void ModulePrinter::print(Module module) {
|
||||
// Output the aliases at the top level.
|
||||
state.printAttributeAliases(os);
|
||||
state.printTypeAliases(os);
|
||||
|
||||
// Print the module.
|
||||
for (auto fn : *module)
|
||||
for (auto fn : module)
|
||||
print(fn);
|
||||
}
|
||||
|
||||
|
@ -1784,8 +1784,8 @@ void Function::dump() { print(llvm::errs()); }
|
|||
|
||||
void Module::print(raw_ostream &os) {
|
||||
ModuleState state(getContext());
|
||||
state.initialize(this);
|
||||
ModulePrinter(os, state).print(this);
|
||||
state.initialize(*this);
|
||||
ModulePrinter(os, state).print(*this);
|
||||
}
|
||||
|
||||
void Module::dump() { print(llvm::errs()); }
|
||||
|
|
|
@ -26,13 +26,13 @@
|
|||
#include "mlir/Support/Functional.h"
|
||||
using namespace mlir;
|
||||
|
||||
Builder::Builder(Module *module) : context(module->getContext()) {}
|
||||
Builder::Builder(Module module) : context(module.getContext()) {}
|
||||
|
||||
Identifier Builder::getIdentifier(StringRef str) {
|
||||
return Identifier::get(str, context);
|
||||
}
|
||||
|
||||
Module *Builder::createModule() { return new Module(context); }
|
||||
Module Builder::createModule() { return Module::create(context); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Locations.
|
||||
|
|
|
@ -43,12 +43,14 @@ FunctionStorage::FunctionStorage(Location location, StringRef name,
|
|||
type(type), attrs(attrs), argAttrs(argAttrs), body(this) {}
|
||||
|
||||
MLIRContext *Function::getContext() { return getType().getContext(); }
|
||||
Module Function::getModule() { return impl->module; }
|
||||
|
||||
Module *llvm::ilist_traits<FunctionStorage>::getContainingModule() {
|
||||
size_t Offset(
|
||||
size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr))));
|
||||
ModuleStorage *llvm::ilist_traits<FunctionStorage>::getContainingModule() {
|
||||
size_t Offset(size_t(
|
||||
&((ModuleStorage *)nullptr->*ModuleStorage::getSublistAccess(nullptr))));
|
||||
iplist<FunctionStorage> *Anchor(static_cast<iplist<FunctionStorage> *>(this));
|
||||
return reinterpret_cast<Module *>(reinterpret_cast<char *>(Anchor) - Offset);
|
||||
return reinterpret_cast<ModuleStorage *>(reinterpret_cast<char *>(Anchor) -
|
||||
Offset);
|
||||
}
|
||||
|
||||
/// This is a trait method invoked when a Function is added to a Module. We
|
||||
|
@ -74,7 +76,7 @@ void llvm::ilist_traits<FunctionStorage>::transferNodesFromList(
|
|||
function_iterator last) {
|
||||
// If we are transferring functions within the same module, the Module
|
||||
// pointer doesn't need to be updated.
|
||||
Module *curParent = getContainingModule();
|
||||
ModuleStorage *curParent = getContainingModule();
|
||||
if (curParent == otherList.getContainingModule())
|
||||
return;
|
||||
|
||||
|
@ -87,8 +89,8 @@ void llvm::ilist_traits<FunctionStorage>::transferNodesFromList(
|
|||
|
||||
/// Unlink this function from its Module and delete it.
|
||||
void Function::erase() {
|
||||
if (auto *module = getModule())
|
||||
getModule()->functions.erase(impl);
|
||||
if (auto module = getModule())
|
||||
module.impl->functions.erase(impl);
|
||||
else
|
||||
delete impl;
|
||||
}
|
||||
|
|
|
@ -21,8 +21,8 @@
|
|||
using namespace mlir;
|
||||
|
||||
/// Build a symbol table with the symbols within the given module.
|
||||
SymbolTable::SymbolTable(Module *module) : context(module->getContext()) {
|
||||
for (auto func : *module) {
|
||||
SymbolTable::SymbolTable(Module module) : context(module.getContext()) {
|
||||
for (auto func : module) {
|
||||
auto inserted = symbolTable.insert({func.getName(), func});
|
||||
(void)inserted;
|
||||
assert(inserted.second &&
|
||||
|
|
|
@ -170,13 +170,13 @@ public:
|
|||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
auto *module = op->getFunction().getModule();
|
||||
Function mallocFunc = module->getNamedFunction("malloc");
|
||||
auto module = op->getFunction().getModule();
|
||||
Function mallocFunc = module.getNamedFunction("malloc");
|
||||
if (!mallocFunc) {
|
||||
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
|
||||
mallocFunc =
|
||||
Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
|
||||
module->push_back(mallocFunc);
|
||||
module.push_back(mallocFunc);
|
||||
}
|
||||
|
||||
// Get MLIR types for injecting element pointer.
|
||||
|
@ -231,12 +231,12 @@ public:
|
|||
auto voidPtrTy =
|
||||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
auto *module = op->getFunction().getModule();
|
||||
Function freeFunc = module->getNamedFunction("free");
|
||||
auto module = op->getFunction().getModule();
|
||||
Function freeFunc = module.getNamedFunction("free");
|
||||
if (!freeFunc) {
|
||||
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
|
||||
freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
|
||||
module->push_back(freeFunc);
|
||||
module.push_back(freeFunc);
|
||||
}
|
||||
|
||||
// Get MLIR types for extracting element pointer.
|
||||
|
@ -576,7 +576,7 @@ public:
|
|||
static Function getLLVMLibraryCallImplDefinition(Function libFn) {
|
||||
auto implFnName = (libFn.getName().str() + "_impl");
|
||||
auto module = libFn.getModule();
|
||||
if (auto f = module->getNamedFunction(implFnName)) {
|
||||
if (auto f = module.getNamedFunction(implFnName)) {
|
||||
return f;
|
||||
}
|
||||
SmallVector<Type, 4> fnArgTypes;
|
||||
|
@ -590,7 +590,7 @@ static Function getLLVMLibraryCallImplDefinition(Function libFn) {
|
|||
|
||||
// Insert the implementation function definition.
|
||||
auto implFnDefn = Function::create(libFn.getLoc(), implFnName, implFnType);
|
||||
module->push_back(implFnDefn);
|
||||
module.push_back(implFnDefn);
|
||||
return implFnDefn;
|
||||
}
|
||||
|
||||
|
@ -603,7 +603,7 @@ static Function getLLVMLibraryCallDeclaration(Operation *op,
|
|||
assert(isa<LinalgOp>(op));
|
||||
auto fnName = LinalgOp::getLibraryCallName();
|
||||
auto module = op->getFunction().getModule();
|
||||
if (auto f = module->getNamedFunction(fnName)) {
|
||||
if (auto f = module.getNamedFunction(fnName)) {
|
||||
return f;
|
||||
}
|
||||
|
||||
|
@ -620,7 +620,7 @@ static Function getLLVMLibraryCallDeclaration(Operation *op,
|
|||
"have void return types");
|
||||
auto libFnType = FunctionType::get(inputTypes, {}, op->getContext());
|
||||
auto libFn = Function::create(op->getLoc(), fnName, libFnType);
|
||||
module->push_back(libFn);
|
||||
module.push_back(libFn);
|
||||
// Return after creating the function definition. The body will be created
|
||||
// later.
|
||||
return libFn;
|
||||
|
@ -802,7 +802,7 @@ static void lowerLinalgForToCFG(Function &f) {
|
|||
}
|
||||
|
||||
void LowerLinalgToLLVMPass::runOnModule() {
|
||||
auto &module = getModule();
|
||||
auto module = getModule();
|
||||
|
||||
for (auto f : module.getFunctions()) {
|
||||
lowerLinalgSubViewOps(f);
|
||||
|
|
|
@ -3857,7 +3857,7 @@ class ModuleParser : public Parser {
|
|||
public:
|
||||
explicit ModuleParser(ParserState &state) : Parser(state) {}
|
||||
|
||||
ParseResult parseModule(Module *module);
|
||||
ParseResult parseModule(Module module);
|
||||
|
||||
private:
|
||||
/// Parse an attribute alias declaration.
|
||||
|
@ -3875,7 +3875,7 @@ private:
|
|||
StringRef &name, FunctionType &type,
|
||||
SmallVectorImpl<std::pair<SMLoc, StringRef>> &argNames,
|
||||
SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs);
|
||||
ParseResult parseFunc(Module *module);
|
||||
ParseResult parseFunc(Module module);
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -4039,7 +4039,7 @@ ParseResult ModuleParser::parseFunctionSignature(
|
|||
/// function-body ::= `{` block+ `}`
|
||||
/// function-attributes ::= `attributes` attribute-dict
|
||||
///
|
||||
ParseResult ModuleParser::parseFunc(Module *module) {
|
||||
ParseResult ModuleParser::parseFunc(Module module) {
|
||||
consumeToken();
|
||||
|
||||
StringRef name;
|
||||
|
@ -4061,7 +4061,7 @@ ParseResult ModuleParser::parseFunc(Module *module) {
|
|||
// Okay, the function signature was parsed correctly, create the function now.
|
||||
auto function =
|
||||
Function::create(getEncodedSourceLocation(loc), name, type, attrs);
|
||||
module->push_back(function);
|
||||
module.push_back(function);
|
||||
|
||||
// Parse an optional trailing location.
|
||||
if (parseOptionalTrailingLocation(function))
|
||||
|
@ -4097,7 +4097,7 @@ ParseResult ModuleParser::parseFunc(Module *module) {
|
|||
}
|
||||
|
||||
/// This is the top-level module parser.
|
||||
ParseResult ModuleParser::parseModule(Module *module) {
|
||||
ParseResult ModuleParser::parseModule(Module module) {
|
||||
while (1) {
|
||||
switch (getToken().getKind()) {
|
||||
default:
|
||||
|
@ -4139,16 +4139,15 @@ ParseResult ModuleParser::parseModule(Module *module) {
|
|||
/// This parses the file specified by the indicated SourceMgr and returns an
|
||||
/// MLIR module if it was valid. If not, it emits diagnostics and returns
|
||||
/// null.
|
||||
Module *mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
|
||||
MLIRContext *context) {
|
||||
Module mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
|
||||
MLIRContext *context) {
|
||||
|
||||
// This is the result module we are parsing into.
|
||||
std::unique_ptr<Module> module(new Module(context));
|
||||
OwningModuleRef module(Module::create(context));
|
||||
|
||||
ParserState state(sourceMgr, context);
|
||||
if (ModuleParser(state).parseModule(module.get())) {
|
||||
if (ModuleParser(state).parseModule(*module))
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Make sure the parse module has no other structural problems detected by
|
||||
// the verifier.
|
||||
|
@ -4161,7 +4160,7 @@ Module *mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
|
|||
/// This parses the file specified by the indicated filename and returns an
|
||||
/// MLIR module if it was valid. If not, the error message is emitted through
|
||||
/// the error handler registered in the context, and a null pointer is returned.
|
||||
Module *mlir::parseSourceFile(StringRef filename, MLIRContext *context) {
|
||||
Module mlir::parseSourceFile(StringRef filename, MLIRContext *context) {
|
||||
llvm::SourceMgr sourceMgr;
|
||||
return parseSourceFile(filename, sourceMgr, context);
|
||||
}
|
||||
|
@ -4170,8 +4169,8 @@ Module *mlir::parseSourceFile(StringRef filename, MLIRContext *context) {
|
|||
/// SourceMgr and returns an MLIR module if it was valid. If not, the error
|
||||
/// message is emitted through the error handler registered in the context, and
|
||||
/// a null pointer is returned.
|
||||
Module *mlir::parseSourceFile(StringRef filename, llvm::SourceMgr &sourceMgr,
|
||||
MLIRContext *context) {
|
||||
Module mlir::parseSourceFile(StringRef filename, llvm::SourceMgr &sourceMgr,
|
||||
MLIRContext *context) {
|
||||
if (sourceMgr.getNumBuffers() != 0) {
|
||||
// TODO(b/136086478): Extend to support multiple buffers.
|
||||
emitError(mlir::UnknownLoc::get(context),
|
||||
|
@ -4192,7 +4191,7 @@ Module *mlir::parseSourceFile(StringRef filename, llvm::SourceMgr &sourceMgr,
|
|||
|
||||
/// This parses the program string to a MLIR module if it was valid. If not,
|
||||
/// it emits diagnostics and returns null.
|
||||
Module *mlir::parseSourceString(StringRef moduleStr, MLIRContext *context) {
|
||||
Module mlir::parseSourceString(StringRef moduleStr, MLIRContext *context) {
|
||||
auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr);
|
||||
if (!memBuffer)
|
||||
return nullptr;
|
||||
|
|
|
@ -66,7 +66,7 @@ static void printIR(const llvm::Any &ir, bool printModuleScope,
|
|||
|
||||
// Print the function name and a newline before the Module.
|
||||
out << " (function: " << function.getName() << ")\n";
|
||||
function.getModule()->print(out);
|
||||
function.getModule().print(out);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -80,8 +80,8 @@ static void printIR(const llvm::Any &ir, bool printModuleScope,
|
|||
}
|
||||
|
||||
// Print the given module.
|
||||
assert(llvm::any_isa<Module *>(ir) && "unexpected IR unit");
|
||||
llvm::any_cast<Module *>(ir)->print(out);
|
||||
assert(llvm::any_isa<Module>(ir) && "unexpected IR unit");
|
||||
llvm::any_cast<Module>(ir).print(out);
|
||||
}
|
||||
|
||||
/// Instrumentation hooks.
|
||||
|
|
|
@ -75,7 +75,7 @@ LogicalResult FunctionPassBase::run(Function fn, FunctionAnalysisManager &fam) {
|
|||
}
|
||||
|
||||
/// Forwarding function to execute this pass.
|
||||
LogicalResult ModulePassBase::run(Module *module, ModuleAnalysisManager &mam) {
|
||||
LogicalResult ModulePassBase::run(Module module, ModuleAnalysisManager &mam) {
|
||||
// Initialize the pass state.
|
||||
passState.emplace(module, mam);
|
||||
|
||||
|
@ -124,7 +124,7 @@ LogicalResult detail::FunctionPassExecutor::run(Function function,
|
|||
}
|
||||
|
||||
/// Run all of the passes in this manager over the current module.
|
||||
LogicalResult detail::ModulePassExecutor::run(Module *module,
|
||||
LogicalResult detail::ModulePassExecutor::run(Module module,
|
||||
ModuleAnalysisManager &mam) {
|
||||
// Run each of the held passes.
|
||||
for (auto &pass : passes)
|
||||
|
@ -261,7 +261,7 @@ PassManager::PassManager(bool verifyPasses)
|
|||
PassManager::~PassManager() {}
|
||||
|
||||
/// Run the passes within this manager on the provided module.
|
||||
LogicalResult PassManager::run(Module *module) {
|
||||
LogicalResult PassManager::run(Module module) {
|
||||
ModuleAnalysisManager mam(module, instrumentor.get());
|
||||
return mpe->run(module, mam);
|
||||
}
|
||||
|
|
|
@ -76,7 +76,7 @@ public:
|
|||
ModulePassExecutor &operator=(const ModulePassExecutor &) = delete;
|
||||
|
||||
/// Run the executor on the given module.
|
||||
LogicalResult run(Module *module, ModuleAnalysisManager &mam);
|
||||
LogicalResult run(Module module, ModuleAnalysisManager &mam);
|
||||
|
||||
/// Add a pass to the current executor. This takes ownership over the provided
|
||||
/// pass pointer.
|
||||
|
|
|
@ -34,10 +34,10 @@ using namespace mlir;
|
|||
|
||||
// Adds a one-block function named as `spirv_module` to `module` and returns the
|
||||
// block. The created block will be terminated by `std.return`.
|
||||
Block *createOneBlockFunction(Builder builder, Module *module) {
|
||||
Block *createOneBlockFunction(Builder builder, Module module) {
|
||||
auto fnType = builder.getFunctionType(/*inputs=*/{}, /*results=*/{});
|
||||
auto fn = Function::create(builder.getUnknownLoc(), "spirv_module", fnType);
|
||||
module->push_back(fn);
|
||||
module.push_back(fn);
|
||||
|
||||
auto *block = new Block();
|
||||
fn.push_back(block);
|
||||
|
@ -51,8 +51,8 @@ Block *createOneBlockFunction(Builder builder, Module *module) {
|
|||
|
||||
// Deserializes the SPIR-V binary module stored in the file named as
|
||||
// `inputFilename` and returns a module containing the SPIR-V module.
|
||||
std::unique_ptr<Module> deserializeModule(llvm::StringRef inputFilename,
|
||||
MLIRContext *context) {
|
||||
OwningModuleRef deserializeModule(llvm::StringRef inputFilename,
|
||||
MLIRContext *context) {
|
||||
Builder builder(context);
|
||||
|
||||
std::string errorMessage;
|
||||
|
@ -83,7 +83,7 @@ std::unique_ptr<Module> deserializeModule(llvm::StringRef inputFilename,
|
|||
// converted SPIR-V ModuleOp inside a MLIR module. This should be changed to
|
||||
// return the SPIR-V ModuleOp directly after module and function are migrated
|
||||
// to be general ops.
|
||||
std::unique_ptr<Module> module(builder.createModule());
|
||||
OwningModuleRef module(builder.createModule());
|
||||
Block *block = createOneBlockFunction(builder, module.get());
|
||||
block->push_front(spirvModule->getOperation());
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
LogicalResult serializeModule(Module *module, StringRef outputFilename) {
|
||||
LogicalResult serializeModule(Module module, StringRef outputFilename) {
|
||||
if (!module)
|
||||
return failure();
|
||||
|
||||
|
@ -45,7 +45,7 @@ LogicalResult serializeModule(Module *module, StringRef outputFilename) {
|
|||
// wrapping the SPIR-V ModuleOp inside a MLIR module. This should be changed
|
||||
// to take in the SPIR-V ModuleOp directly after module and function are
|
||||
// migrated to be general ops.
|
||||
for (auto fn : *module) {
|
||||
for (auto fn : module) {
|
||||
fn.walk<spirv::ModuleOp>([&](spirv::ModuleOp spirvModule) {
|
||||
if (done) {
|
||||
spirvModule.emitError("found more than one 'spv.module' op");
|
||||
|
@ -73,6 +73,6 @@ LogicalResult serializeModule(Module *module, StringRef outputFilename) {
|
|||
|
||||
static TranslateFromMLIRRegistration
|
||||
registration("serialize-spirv",
|
||||
[](Module *module, StringRef outputFilename) {
|
||||
[](Module module, StringRef outputFilename) {
|
||||
return failed(serializeModule(module, outputFilename));
|
||||
});
|
||||
|
|
|
@ -440,7 +440,7 @@ static LogicalResult verify(CallOp op) {
|
|||
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
|
||||
if (!fnAttr)
|
||||
return op.emitOpError("requires a 'callee' function attribute");
|
||||
auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction(
|
||||
auto fn = op.getOperation()->getFunction().getModule().getNamedFunction(
|
||||
fnAttr.getValue());
|
||||
if (!fn)
|
||||
return op.emitOpError() << "'" << fnAttr.getValue()
|
||||
|
@ -1107,7 +1107,7 @@ static LogicalResult verify(ConstantOp &op) {
|
|||
return op.emitOpError("requires 'value' to be a function reference");
|
||||
|
||||
// Try to find the referenced function.
|
||||
auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction(
|
||||
auto fn = op.getOperation()->getFunction().getModule().getNamedFunction(
|
||||
fnAttr.getValue());
|
||||
if (!fn)
|
||||
return op.emitOpError("reference to undefined function 'bar'");
|
||||
|
|
|
@ -50,7 +50,7 @@ static LogicalResult
|
|||
performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
|
||||
SourceMgr &sourceMgr, MLIRContext *context,
|
||||
const std::vector<const mlir::PassRegistryEntry *> &passList) {
|
||||
std::unique_ptr<Module> module(parseSourceFile(sourceMgr, context));
|
||||
OwningModuleRef module(parseSourceFile(sourceMgr, context));
|
||||
if (!module)
|
||||
return failure();
|
||||
|
||||
|
@ -63,7 +63,7 @@ performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
|
|||
applyPassManagerCLOptions(pm);
|
||||
|
||||
// Run the pipeline.
|
||||
if (failed(pm.run(module.get())))
|
||||
if (failed(pm.run(*module)))
|
||||
return failure();
|
||||
|
||||
// Print the output.
|
||||
|
|
|
@ -37,7 +37,7 @@ using namespace mlir;
|
|||
// Storage for the translation function wrappers that survive the parser.
|
||||
static llvm::SmallVector<TranslateFunction, 16> wrapperStorage;
|
||||
|
||||
static LogicalResult printMLIROutput(Module &module,
|
||||
static LogicalResult printMLIROutput(Module module,
|
||||
llvm::StringRef outputFilename) {
|
||||
if (failed(module.verify()))
|
||||
return failure();
|
||||
|
@ -62,7 +62,7 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
|
|||
TranslateFunction wrapper = [function](StringRef inputFilename,
|
||||
StringRef outputFilename,
|
||||
MLIRContext *context) {
|
||||
std::unique_ptr<Module> module = function(inputFilename, context);
|
||||
OwningModuleRef module = function(inputFilename, context);
|
||||
if (!module)
|
||||
return failure();
|
||||
return printMLIROutput(*module, outputFilename);
|
||||
|
@ -79,8 +79,8 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
|
|||
MLIRContext *context) {
|
||||
llvm::SourceMgr sourceMgr;
|
||||
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
|
||||
auto module = std::unique_ptr<Module>(
|
||||
parseSourceFile(inputFilename, sourceMgr, context));
|
||||
auto module =
|
||||
OwningModuleRef(parseSourceFile(inputFilename, sourceMgr, context));
|
||||
if (!module)
|
||||
return failure();
|
||||
return failure(function(module.get(), outputFilename));
|
||||
|
|
|
@ -31,16 +31,16 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
std::unique_ptr<llvm::Module> mlir::translateModuleToLLVMIR(Module &m) {
|
||||
std::unique_ptr<llvm::Module> mlir::translateModuleToLLVMIR(Module m) {
|
||||
return LLVM::ModuleTranslation::translateModule<>(m);
|
||||
}
|
||||
|
||||
static TranslateFromMLIRRegistration registration(
|
||||
"mlir-to-llvmir", [](Module *module, llvm::StringRef outputFilename) {
|
||||
"mlir-to-llvmir", [](Module module, llvm::StringRef outputFilename) {
|
||||
if (!module)
|
||||
return true;
|
||||
|
||||
auto llvmModule = LLVM::ModuleTranslation::translateModule<>(*module);
|
||||
auto llvmModule = LLVM::ModuleTranslation::translateModule<>(module);
|
||||
if (!llvmModule)
|
||||
return true;
|
||||
|
||||
|
|
|
@ -47,8 +47,7 @@ static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder,
|
|||
class ModuleTranslation : public LLVM::ModuleTranslation {
|
||||
|
||||
public:
|
||||
explicit ModuleTranslation(Module &module)
|
||||
: LLVM::ModuleTranslation(module) {}
|
||||
explicit ModuleTranslation(Module module) : LLVM::ModuleTranslation(module) {}
|
||||
~ModuleTranslation() override {}
|
||||
|
||||
protected:
|
||||
|
@ -62,7 +61,7 @@ protected:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) {
|
||||
std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module m) {
|
||||
ModuleTranslation translation(m);
|
||||
auto llvmModule =
|
||||
LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
|
||||
|
@ -91,11 +90,11 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) {
|
|||
|
||||
static TranslateFromMLIRRegistration
|
||||
registration("mlir-to-nvvmir",
|
||||
[](Module *module, llvm::StringRef outputFilename) {
|
||||
[](Module module, llvm::StringRef outputFilename) {
|
||||
if (!module)
|
||||
return true;
|
||||
|
||||
auto llvmModule = mlir::translateModuleToNVVMIR(*module);
|
||||
auto llvmModule = mlir::translateModuleToNVVMIR(module);
|
||||
if (!llvmModule)
|
||||
return true;
|
||||
|
||||
|
|
|
@ -275,7 +275,7 @@ static Value *getPHISourceValue(Block *current, Block *pred,
|
|||
: terminator.getSuccessorOperand(1, index);
|
||||
}
|
||||
|
||||
void ModuleTranslation::connectPHINodes(Function &func) {
|
||||
void ModuleTranslation::connectPHINodes(Function func) {
|
||||
// Skip the first block, it cannot be branched to and its arguments correspond
|
||||
// to the arguments of the LLVM function.
|
||||
for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
|
||||
|
@ -306,7 +306,7 @@ static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
|
|||
}
|
||||
|
||||
// Sort function blocks topologically.
|
||||
static llvm::SetVector<Block *> topologicalSort(Function &f) {
|
||||
static llvm::SetVector<Block *> topologicalSort(Function f) {
|
||||
// For each blocks that has not been visited yet (i.e. that has no
|
||||
// predecessors), add it to the list and traverse its successors in DFS
|
||||
// preorder.
|
||||
|
@ -320,7 +320,7 @@ static llvm::SetVector<Block *> topologicalSort(Function &f) {
|
|||
return blocks;
|
||||
}
|
||||
|
||||
bool ModuleTranslation::convertOneFunction(Function &func) {
|
||||
bool ModuleTranslation::convertOneFunction(Function func) {
|
||||
// Clear the block and value mappings, they are only relevant within one
|
||||
// function.
|
||||
blockMapping.clear();
|
||||
|
@ -404,7 +404,7 @@ bool ModuleTranslation::convertFunctions() {
|
|||
return false;
|
||||
}
|
||||
|
||||
std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(Module &m) {
|
||||
std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(Module m) {
|
||||
auto *dialect = m.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
assert(dialect && "LLVM dialect must be registered");
|
||||
|
||||
|
|
|
@ -1128,7 +1128,7 @@ auto ConversionTarget::getOpAction(OperationName op) const
|
|||
/// conversion object. If conversion fails for specific functions, those
|
||||
/// functions remains unmodified.
|
||||
LogicalResult
|
||||
mlir::applyConversionPatterns(Module &module, ConversionTarget &target,
|
||||
mlir::applyConversionPatterns(Module module, ConversionTarget &target,
|
||||
TypeConverter &converter,
|
||||
OwningRewritePatternList &&patterns) {
|
||||
SmallVector<Function, 32> allFunctions(module.getFunctions());
|
||||
|
|
|
@ -555,8 +555,8 @@ TEST_FUNC(vectorize_2d) {
|
|||
makeFunction("vectorize_2d", {}, {memrefType, memrefType, memrefType});
|
||||
|
||||
mlir::Function f = owningF;
|
||||
mlir::Module module(&globalContext());
|
||||
module.push_back(f);
|
||||
mlir::OwningModuleRef module = Module::create(&globalContext());
|
||||
module->push_back(f);
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
|
|
|
@ -89,8 +89,8 @@ static llvm::cl::list<std::string>
|
|||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
|
||||
llvm::cl::cat(clOptionsCategory));
|
||||
|
||||
static std::unique_ptr<Module> parseMLIRInput(StringRef inputFilename,
|
||||
MLIRContext *context) {
|
||||
static OwningModuleRef parseMLIRInput(StringRef inputFilename,
|
||||
MLIRContext *context) {
|
||||
// Set up the input file.
|
||||
std::string errorMessage;
|
||||
auto file = openInputFile(inputFilename, &errorMessage);
|
||||
|
@ -101,7 +101,7 @@ static std::unique_ptr<Module> parseMLIRInput(StringRef inputFilename,
|
|||
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
|
||||
return std::unique_ptr<Module>(parseSourceFile(sourceMgr, context));
|
||||
return OwningModuleRef(parseSourceFile(sourceMgr, context));
|
||||
}
|
||||
|
||||
// Initialize the relevant subsystems of LLVM.
|
||||
|
@ -151,7 +151,7 @@ static void printMemRefArguments(ArrayRef<Type> argTypes,
|
|||
// - canonicalization
|
||||
// - affine to standard lowering
|
||||
// - standard to llvm lowering
|
||||
static LogicalResult convertAffineStandardToLLVMIR(Module *module) {
|
||||
static LogicalResult convertAffineStandardToLLVMIR(Module module) {
|
||||
PassManager manager;
|
||||
manager.addPass(mlir::createCanonicalizerPass());
|
||||
manager.addPass(mlir::createCSEPass());
|
||||
|
@ -161,9 +161,9 @@ static LogicalResult convertAffineStandardToLLVMIR(Module *module) {
|
|||
}
|
||||
|
||||
static Error compileAndExecuteFunctionWithMemRefs(
|
||||
Module *module, StringRef entryPoint,
|
||||
Module module, StringRef entryPoint,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||
Function mainFunction = module->getNamedFunction(entryPoint);
|
||||
Function mainFunction = module.getNamedFunction(entryPoint);
|
||||
if (!mainFunction || mainFunction.getBlocks().empty()) {
|
||||
return make_string_error("entry point not found");
|
||||
}
|
||||
|
@ -204,9 +204,9 @@ static Error compileAndExecuteFunctionWithMemRefs(
|
|||
}
|
||||
|
||||
static Error compileAndExecuteSingleFloatReturnFunction(
|
||||
Module *module, StringRef entryPoint,
|
||||
Module module, StringRef entryPoint,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||
Function mainFunction = module->getNamedFunction(entryPoint);
|
||||
Function mainFunction = module.getNamedFunction(entryPoint);
|
||||
if (!mainFunction || mainFunction.isExternal()) {
|
||||
return make_string_error("entry point not found");
|
||||
}
|
||||
|
|
|
@ -26,19 +26,19 @@ namespace {
|
|||
/// Minimal class definitions for two analyses.
|
||||
struct MyAnalysis {
|
||||
MyAnalysis(Function) {}
|
||||
MyAnalysis(Module *) {}
|
||||
MyAnalysis(Module) {}
|
||||
};
|
||||
struct OtherAnalysis {
|
||||
OtherAnalysis(Function) {}
|
||||
OtherAnalysis(Module *) {}
|
||||
OtherAnalysis(Module) {}
|
||||
};
|
||||
|
||||
TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
|
||||
MLIRContext context;
|
||||
|
||||
// Test fine grain invalidation of the module analysis manager.
|
||||
std::unique_ptr<Module> module(new Module(&context));
|
||||
ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr);
|
||||
OwningModuleRef module(Module::create(&context));
|
||||
ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
|
||||
|
||||
// Query two different analyses, but only preserve one before invalidating.
|
||||
mam.getAnalysis<MyAnalysis>();
|
||||
|
@ -58,14 +58,14 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
|
|||
Builder builder(&context);
|
||||
|
||||
// Create a function and a module.
|
||||
std::unique_ptr<Module> module(new Module(&context));
|
||||
OwningModuleRef module(Module::create(&context));
|
||||
Function func1 =
|
||||
Function::create(builder.getUnknownLoc(), "foo",
|
||||
builder.getFunctionType(llvm::None, llvm::None));
|
||||
module->push_back(func1);
|
||||
|
||||
// Test fine grain invalidation of the function analysis manager.
|
||||
ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr);
|
||||
ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
|
||||
FunctionAnalysisManager fam = mam.slice(func1);
|
||||
|
||||
// Query two different analyses, but only preserve one before invalidating.
|
||||
|
@ -86,7 +86,7 @@ TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) {
|
|||
Builder builder(&context);
|
||||
|
||||
// Create a function and a module.
|
||||
std::unique_ptr<Module> module(new Module(&context));
|
||||
OwningModuleRef module(Module::create(&context));
|
||||
Function func1 =
|
||||
Function::create(builder.getUnknownLoc(), "foo",
|
||||
builder.getFunctionType(llvm::None, llvm::None));
|
||||
|
@ -94,7 +94,7 @@ TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) {
|
|||
|
||||
// Test fine grain invalidation of a function analysis from within a module
|
||||
// analysis manager.
|
||||
ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr);
|
||||
ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
|
||||
|
||||
// Query two different analyses, but only preserve one before invalidating.
|
||||
mam.getFunctionAnalysis<MyAnalysis>(func1);
|
||||
|
|
Loading…
Reference in New Issue