bugfix for UpdateWeights & ExportModel

This commit is contained in:
zhangyanhui 2022-12-09 17:34:58 +08:00
parent d75517ad31
commit d83959b6f8
2 changed files with 7 additions and 1 deletions

View File

@ -228,9 +228,13 @@ int TrainSession::UpdateWeights(std::vector<lite::Tensor *> modify_tensors) {
return RET_PARAM_INVALID;
}
if (modify->tensor_name() == tensor->tensor_name()) {
if (tensor->Size() != modify->Size()) {
model_buff_changed_ = true;
}
auto ret = ReshapeWeightTensor(tensor, modify);
num_of_found_tensors++;
if (ret != RET_OK) {
model_buff_changed_ = false;
return ret;
}
break;
@ -244,6 +248,7 @@ int TrainSession::UpdateWeights(std::vector<lite::Tensor *> modify_tensors) {
auto ret = ReSizeKernels(kernels_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Resize kernels fail!";
model_buff_changed_ = false;
return ret;
}
@ -1183,7 +1188,7 @@ int TrainSession::ExportInner(DestType destination, ModelType model_type, Quanti
TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
} else {
if ((quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
})) {

View File

@ -177,6 +177,7 @@ class TrainSession : virtual public lite::LiteSession {
void *workspace_ = nullptr;
SchedCallBack sched_mix_precision_callback_;
bool train_mode_ = false;
bool model_buff_changed_ = false;
void *tensors_data_ = nullptr;
size_t tensors_data_size_ = 0;
std::shared_ptr<Allocator> allocator_;