[mlir-reduce] Fix the memory leak and recycle unused modules.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D105416
This commit is contained in:
Chia-hung Duan 2021-07-08 20:03:23 +08:00
parent 026bb84bcd
commit ba913b8da5
3 changed files with 12 additions and 7 deletions

View File

@ -20,6 +20,7 @@
#include <queue>
#include <vector>
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
@ -57,7 +58,7 @@ public:
/// will have been applied certain reduction strategies. Note that it's not
/// necessary to be an interesting case or a reduced module (has smaller size
/// than parent's).
ModuleOp getModule() const { return module; }
ModuleOp getModule() const { return module.get(); }
/// Return the region we're reducing.
Region &getRegion() const { return *region; }
@ -141,7 +142,7 @@ private:
/// This is a copy of module from parent node. All the reducer patterns will
/// be applied to this instance.
ModuleOp module;
OwningOpRef<ModuleOp> module;
/// The region of certain operation we're reducing in the module
Region *region;

View File

@ -112,6 +112,9 @@ void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
// This module may has been updated. Reset the range.
ranges.clear();
ranges.push_back({0, std::distance(region->op_begin(), region->op_end())});
} else {
// Release the uninteresting module to save some memory.
module.release()->erase();
}
}

View File

@ -28,7 +28,8 @@
using namespace mlir;
// Parse and verify the input MLIR file.
static LogicalResult loadModule(MLIRContext &context, OwningModuleRef &module,
static LogicalResult loadModule(MLIRContext &context,
OwningOpRef<ModuleOp> &module,
StringRef inputFilename) {
module = parseSourceFile(inputFilename, &context);
if (!module)
@ -75,7 +76,7 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv,
if (!output)
return failure();
mlir::OwningModuleRef moduleRef;
OwningOpRef<ModuleOp> moduleRef;
if (failed(loadModule(context, moduleRef, inputFilename)))
return failure();
@ -88,12 +89,12 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv,
if (failed(parser.addToPipeline(pm, errorHandler)))
return failure();
ModuleOp m = moduleRef.get().clone();
OwningOpRef<ModuleOp> m = moduleRef.get().clone();
if (failed(pm.run(m)))
if (failed(pm.run(m.get())))
return failure();
m.print(output->os());
m->print(output->os());
output->keep();
return success();