!4915 [MS][LITE][Develop]mem check fixed

Merge pull request !4915 from wangchangkai/master
This commit is contained in:
mindspore-ci-bot 2020-08-22 11:33:55 +08:00 committed by Gitee
commit 42a092d687
6 changed files with 48 additions and 2 deletions

View File

@ -29,7 +29,14 @@ namespace mindspore {
class ParamValueLite : public Value {
public:
ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {}
virtual ~ParamValueLite() = default;
virtual ~ParamValueLite() {
if (tensor_addr_ != nullptr) {
auto tensor_mem = reinterpret_cast<char*>(tensor_addr_);
delete tensor_mem;
tensor_addr_ = nullptr;
tensor_size_ = 0;
}
}
size_t tensor_size() const { return tensor_size_; }
void set_tensor_size(size_t size) { tensor_size_ = size; }

View File

@ -137,6 +137,7 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod
std::string initial_data = initialize_proto.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data());
MS_EXCEPTION_IF_NULL(tensor_data_buf);
tensor_info->SetData(nullptr);
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size());
if (EOK != ret) {
MS_LOG(ERROR) << "memcpy_s error";
@ -152,6 +153,7 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod
param_value->set_tensor_type(tensor_info->data_type());
param_value->set_tensor_shape(tensor_info->shape());
node->set_default_param(param_value);
delete tensor_info;
}
anfnode_build_map_[value_proto.name()] = node;
return true;

View File

@ -58,7 +58,29 @@ class MindsporeImporter : public Converter {
~MindsporeImporter() override = default;
};
void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto cnodes = func_graph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitiveT_value == nullptr) {
MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
return;
}
auto primT = primitiveT_value->GetPrimitiveT();
if (primT == nullptr) {
MS_LOG(ERROR) << "PrimitiveT is nullptr";
return;
}
if (primT->value.type == schema::PrimitiveType_TupleGetItem ||
primT->value.type == schema::PrimitiveType_MakeTuple ||
primT->value.type == schema::PrimitiveType_Return) {
delete primT;
primitiveT_value->SetPrimitiveT(nullptr);
}
}
return;
}
MetaGraphT *Converter::Convert(const converter::Flags *flag) {
// parse the model and weight file to generate inference data structure
FuncGraphPtr graph = nullptr;
@ -116,6 +138,8 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
MS_LOG(ERROR) << "FBTransform model failed " << status;
return nullptr;
}
FreeFuncGraph(graph);
return meta_graph;
}
@ -171,6 +195,7 @@ int RunConverter(int argc, const char **argv) {
auto onnx_graph = AnfImporterFromProtobuf::ReadOnnxFromBinary(flags->modelFile);
MindsporeImporter mindsporeImporter(onnx_graph, graph);
fb_graph = mindsporeImporter.Convert(flags.get());
delete onnx_graph;
break;
}
case FmkType::FmkType_CAFFE: {
@ -202,6 +227,8 @@ int RunConverter(int argc, const char **argv) {
MS_LOG(ERROR) << "Save graph failed";
return 1;
}
delete fb_graph;
MS_LOG(INFO) << "CONVERT RESULT: SUCCESS!";
return 0;

View File

@ -35,6 +35,7 @@ class Converter {
virtual ~Converter();
virtual schema::MetaGraphT *Convert(const lite::converter::Flags *flags);
void CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags);
void FreeFuncGraph(const FuncGraphPtr &func_graph);
protected:
ModelParser *modelParser = nullptr;

View File

@ -52,6 +52,7 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t);
// when tensorT as graph input
if (lite_tensor_size == 0) {
delete lite_tensor;
return input_tensors;
}
auto tensor_data = new (std::nothrow) char[lite_tensor_size / sizeof(char)];

View File

@ -100,6 +100,14 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
MS_LOG(EXCEPTION) << "Unsupported opType, " << type;
}
pre_node->set_abstract(abstr);
const auto &prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(transform_node->input(0));
if (prim != nullptr) {
auto *prim_t = prim->GetPrimitiveT();
if (prim_t != nullptr) {
delete prim_t;
prim->SetPrimitiveT(nullptr);
}
}
return pre_node;
}