From d83959b6f894ef996e0496a94b0d3e04a332c514 Mon Sep 17 00:00:00 2001 From: zhangyanhui Date: Fri, 9 Dec 2022 17:34:58 +0800 Subject: [PATCH] bugfix for UpdateWeights & ExportModel --- mindspore/lite/src/train/train_session.cc | 7 ++++++- mindspore/lite/src/train/train_session.h | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc index 86c78b054c9..6ff89a0a383 100644 --- a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -228,9 +228,13 @@ int TrainSession::UpdateWeights(std::vector modify_tensors) { return RET_PARAM_INVALID; } if (modify->tensor_name() == tensor->tensor_name()) { + if (tensor->Size() != modify->Size()) { + model_buff_changed_ = true; + } auto ret = ReshapeWeightTensor(tensor, modify); num_of_found_tensors++; if (ret != RET_OK) { + model_buff_changed_ = false; return ret; } break; @@ -244,6 +248,7 @@ int TrainSession::UpdateWeights(std::vector modify_tensors) { auto ret = ReSizeKernels(kernels_); if (ret != RET_OK) { MS_LOG(ERROR) << "Resize kernels fail!"; + model_buff_changed_ = false; return ret; } @@ -1183,7 +1188,7 @@ int TrainSession::ExportInner(DestType destination, ModelType model_type, Quanti TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed."); status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type); } else { - if ((quant_type == QT_NONE) && (model_type == MT_TRAIN) && + if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) && std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) { return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE; })) { diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h index 628db537b19..130296a5e57 100644 --- a/mindspore/lite/src/train/train_session.h +++ b/mindspore/lite/src/train/train_session.h @@ -177,6 +177,7 @@ class TrainSession : virtual public lite::LiteSession { void *workspace_ = nullptr; SchedCallBack sched_mix_precision_callback_; bool train_mode_ = false; + bool model_buff_changed_ = false; void *tensors_data_ = nullptr; size_t tensors_data_size_ = 0; std::shared_ptr allocator_;