From 374012e1d2f74de29edec6ad3c7816950f52ae9d Mon Sep 17 00:00:00 2001
From: guohongzilong <2713219276@qq.com>
Date: Mon, 23 Nov 2020 13:26:11 +0800
Subject: [PATCH] fix train code

---
 mindspore/lite/src/train/loss_kernel.h    |  6 ++--
 mindspore/lite/src/train/train_model.cc   | 11 ++++---
 mindspore/lite/src/train/train_session.cc | 39 ++++++++++++++---------
 3 files changed, 33 insertions(+), 23 deletions(-)

diff --git a/mindspore/lite/src/train/loss_kernel.h b/mindspore/lite/src/train/loss_kernel.h
index 07484b5ecf5..0df3522a523 100644
--- a/mindspore/lite/src/train/loss_kernel.h
+++ b/mindspore/lite/src/train/loss_kernel.h
@@ -22,9 +22,9 @@ namespace mindspore::kernel {
 class LossKernel : public LiteKernel {
  public:
   LossKernel() = default;
-  explicit LossKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
-                      const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
-                      const lite::PrimitiveC *primitive)
+  LossKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+             const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
+             const lite::PrimitiveC *primitive)
       : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~LossKernel() = default;
 };
diff --git a/mindspore/lite/src/train/train_model.cc b/mindspore/lite/src/train/train_model.cc
index 0cb759be238..ee93a62c47e 100644
--- a/mindspore/lite/src/train/train_model.cc
+++ b/mindspore/lite/src/train/train_model.cc
@@ -49,12 +49,14 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
   const void *meta_graph = GetMetaGraphByVerison(model->buf, schema_version);
   if (meta_graph == nullptr) {
     MS_LOG(ERROR) << "meta_graph is nullptr!";
+    free(model->buf);
     delete (model);
     return nullptr;
   }
 
   int status = GenerateModelByVersion(meta_graph, model, schema_version);
   if (status != RET_OK) {
+    free(model->buf);
     delete (model);
     MS_LOG(ERROR) << "fail to generate model";
     return nullptr;
@@ -73,17 +75,16 @@ char *TrainModel::ExportBuf(char *buffer, size_t *len) const {
     MS_LOG(ERROR) << "Model::Export is only available for Train Session";
     return nullptr;
   }
-
   if (*len < buf_size_ && buffer != nullptr) {
     MS_LOG(ERROR) << "Buffer is too small, Export Failed";
     return nullptr;
   }
   if (buffer == nullptr) {
     buffer = reinterpret_cast<char *>(malloc(buf_size_));
-  }
-  if (buffer == nullptr) {
-    MS_LOG(ERROR) << "allocated model buf fail!";
-    return nullptr;
+    if (buffer == nullptr) {
+      MS_LOG(ERROR) << "allocated model buf fail!";
+      return nullptr;
+    }
   }
 
   memcpy(buffer, buf, buf_size_);
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index 2d9bcfb7fab..2497a004c54 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -92,13 +92,22 @@ void TrainSession::AllocWorkSpace() {
 int TrainSession::CompileGraph(lite::Model *model) { return lite::RET_ERROR; }
 
 int TrainSession::CompileTrainGraph(mindspore::lite::TrainModel *model) {
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "model is null";
+    return RET_ERROR;
+  }
   model_ = model;
-
   auto restore = ReplaceOps();
   auto ret = lite::LiteSession::CompileGraph(model);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Compile train graph failed";
+    return RET_ERROR;
+  }
   orig_output_map_ = output_node_map_;
   orig_output_tensor_map_ = output_tensor_map_;
-  for (auto inTensor : inputs_) inTensor->MutableData();
+  for (auto inTensor : inputs_) {
+    inTensor->MutableData();
+  }
   RestoreOps(restore);
   AllocWorkSpace();
   MarkOptimizedKernels();
@@ -152,7 +161,7 @@ int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) {
 int TrainSession::SaveToFile(const std::string &filename) const {
   size_t fb_size = 0;
   auto *buf = reinterpret_cast<char *>(ExportToBuf(nullptr, &fb_size));
-  if (buf == NULL) {
+  if (buf == nullptr) {
     MS_LOG(ERROR) << "Could not Export Trained model";
     return lite::RET_NULL_PTR;
   }
@@ -212,7 +221,7 @@ int TrainSession::Train() {
 }
 
 void TrainSession::UpdateOutputMapByLossKernel(const kernel::LiteKernel *kernel) {
-  if (IsLossKernel(kernel)) {
+  if (kernel != nullptr && IsLossKernel(kernel)) {
     auto *ms_tensor = kernel->out_tensors().at(0);
     if (ms_tensor != nullptr) {
       (void)ms_tensor->MutableData();
@@ -226,7 +235,7 @@ void TrainSession::UpdateOutputMapByLossKernel(const kernel::LiteKernel *kernel) {
 }
 
 void TrainSession::UpdateOutputMapByInKernel(const kernel::LiteKernel *kernel) {
-  if (IsLossKernel(kernel)) {
+  if (kernel != nullptr && IsLossKernel(kernel)) {
     for (auto in_kernel : kernel->in_kernels()) {
       if (output_node_map_.find(in_kernel->name()) == output_node_map_.end()) {
         auto *ms_tensor = in_kernel->out_tensors().at(0);
@@ -304,9 +313,9 @@ void TrainSession::BuildInferenceKernelsMap() {
       }
     } else {
       auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
-      for (auto sb_kernel : sub_graph->nodes()) {
-        if (IsLossKernel(sb_kernel)) {  // For each loss in the system add backward tree
-          for (auto in_node : sb_kernel->in_kernels()) {
+      for (auto sub_kernel : sub_graph->nodes()) {
+        if (IsLossKernel(sub_kernel)) {  // For each loss in the system add backward tree
+          for (auto in_node : sub_kernel->in_kernels()) {
             BuildInferenceKernelsRecursive(in_node, &req_kernels);
           }
         }
@@ -357,9 +366,9 @@ void TrainSession::MarkOptimizedKernels() {
       }
     } else {
       auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
-      for (auto sb_kernel : sub_graph->nodes()) {
-        if (IsOptimizer(sb_kernel)) {
-          std::copy(sb_kernel->in_tensors().begin(), sb_kernel->in_tensors().end(), std::back_inserter(ot));
+      for (auto sub_kernel : sub_graph->nodes()) {
+        if (IsOptimizer(sub_kernel)) {
+          std::copy(sub_kernel->in_tensors().begin(), sub_kernel->in_tensors().end(), std::back_inserter(ot));
         }
       }
     }
@@ -376,11 +385,11 @@ void TrainSession::MarkOptimizedKernels() {
       }
     } else {
       auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
-      for (auto sb_kernel : sub_graph->nodes()) {
-        if (!IsOptimizer(sb_kernel)) {
-          for (auto it : sb_kernel->in_tensors()) {
+      for (auto sub_kernel : sub_graph->nodes()) {
+        if (!IsOptimizer(sub_kernel)) {
+          for (auto it : sub_kernel->in_tensors()) {
             if (std::find(ot.begin(), ot.end(), it) != ot.end()) {
-              sb_kernel->set_trainable(true);
+              sub_kernel->set_trainable(true);
               break;
             }
           }
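Note (not part of the patch): below is a minimal, self-contained sketch of the export-buffer convention that the TrainModel::ExportBuf hunk above adopts, where a caller-supplied buffer is size-checked and a null buffer triggers allocation, with the allocation-failure check now nested inside that branch. ModelBlob, ExportBlob, and the main driver are hypothetical stand-ins for illustration only, not MindSpore Lite APIs.

```cpp
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>

// Hypothetical stand-in for a serialized model held in memory.
struct ModelBlob {
  const char *buf;
  std::size_t buf_size;
};

// Mirrors the control flow of ExportBuf after the patch: validate a
// caller-provided buffer, allocate when the caller passes nullptr, and
// handle allocation failure inside that same branch.
char *ExportBlob(const ModelBlob &model, char *buffer, std::size_t *len) {
  if (len == nullptr || model.buf == nullptr) {
    return nullptr;
  }
  if (buffer != nullptr && *len < model.buf_size) {
    std::fprintf(stderr, "Buffer is too small, export failed\n");
    return nullptr;
  }
  if (buffer == nullptr) {
    buffer = static_cast<char *>(std::malloc(model.buf_size));
    if (buffer == nullptr) {
      std::fprintf(stderr, "allocating the model buffer failed\n");
      return nullptr;
    }
  }
  std::memcpy(buffer, model.buf, model.buf_size);
  *len = model.buf_size;
  return buffer;
}

int main() {
  const char payload[] = "serialized-model-bytes";
  const ModelBlob model{payload, sizeof(payload)};

  // Library-allocated path: pass nullptr and take ownership of the result,
  // analogous to SaveToFile calling ExportToBuf(nullptr, &fb_size).
  std::size_t exported_len = 0;
  char *owned = ExportBlob(model, nullptr, &exported_len);
  if (owned != nullptr) {
    std::printf("exported %zu bytes\n", exported_len);
    std::free(owned);
  }
  return 0;
}
```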