!46660 bugfix for UpdateWeights & ExportModel

Merge pull request !46660 from zhangyanhui/develop_mas
This commit is contained in:
i-robot 2022-12-10 03:51:47 +00:00 committed by Gitee
commit 3223ed3fad
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 7 additions and 1 deletions

View File

@ -228,9 +228,13 @@ int TrainSession::UpdateWeights(std::vector<lite::Tensor *> modify_tensors) {
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
if (modify->tensor_name() == tensor->tensor_name()) { if (modify->tensor_name() == tensor->tensor_name()) {
if (tensor->Size() != modify->Size()) {
model_buff_changed_ = true;
}
auto ret = ReshapeWeightTensor(tensor, modify); auto ret = ReshapeWeightTensor(tensor, modify);
num_of_found_tensors++; num_of_found_tensors++;
if (ret != RET_OK) { if (ret != RET_OK) {
model_buff_changed_ = false;
return ret; return ret;
} }
break; break;
@ -244,6 +248,7 @@ int TrainSession::UpdateWeights(std::vector<lite::Tensor *> modify_tensors) {
auto ret = ReSizeKernels(kernels_); auto ret = ReSizeKernels(kernels_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Resize kernels fail!"; MS_LOG(ERROR) << "Resize kernels fail!";
model_buff_changed_ = false;
return ret; return ret;
} }
@ -1183,7 +1188,7 @@ int TrainSession::ExportInner(DestType destination, ModelType model_type, Quanti
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed."); TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type); status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
} else { } else {
if ((quant_type == QT_NONE) && (model_type == MT_TRAIN) && if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) { std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE; return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
})) { })) {

View File

@ -177,6 +177,7 @@ class TrainSession : virtual public lite::LiteSession {
void *workspace_ = nullptr; void *workspace_ = nullptr;
SchedCallBack sched_mix_precision_callback_; SchedCallBack sched_mix_precision_callback_;
bool train_mode_ = false; bool train_mode_ = false;
bool model_buff_changed_ = false;
void *tensors_data_ = nullptr; void *tensors_data_ = nullptr;
size_t tensors_data_size_ = 0; size_t tensors_data_size_ = 0;
std::shared_ptr<Allocator> allocator_; std::shared_ptr<Allocator> allocator_;