!46660 bugfix for UpdateWeights & ExportModel
Merge pull request !46660 from zhangyanhui/develop_mas
commit 3223ed3fad
|
@@ -228,9 +228,13 @@ int TrainSession::UpdateWeights(std::vector<lite::Tensor *> modify_tensors) {
       return RET_PARAM_INVALID;
     }
     if (modify->tensor_name() == tensor->tensor_name()) {
+      if (tensor->Size() != modify->Size()) {
+        model_buff_changed_ = true;
+      }
       auto ret = ReshapeWeightTensor(tensor, modify);
       num_of_found_tensors++;
       if (ret != RET_OK) {
+        model_buff_changed_ = false;
         return ret;
       }
       break;
@@ -244,6 +248,7 @@ int TrainSession::UpdateWeights(std::vector<lite::Tensor *> modify_tensors) {
   auto ret = ReSizeKernels(kernels_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Resize kernels fail!";
+    model_buff_changed_ = false;
     return ret;
   }

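Taken together, the two UpdateWeights hunks track whether the serialized model buffer is still in sync with the live weights: the flag is raised when an incoming weight tensor arrives with a different size, and lowered again whenever the update cannot be completed (ReshapeWeightTensor or ReSizeKernels returning an error, in which case the buffer was never actually modified). Below is a minimal, self-contained sketch of that bookkeeping with toy types standing in for lite::Tensor and the session class; only model_buff_changed_ and the control flow reflect the patch, everything else is illustrative.

#include <cstddef>
#include <string>
#include <vector>

struct ToyTensor {                 // stand-in for lite::Tensor
  std::string name;
  size_t size = 0;
};

class ToySession {
 public:
  // Mirrors the patched UpdateWeights logic: remember that the serialized
  // model buffer is stale as soon as a weight arrives with a new size, and
  // forget it again if the update fails and nothing was actually changed.
  bool UpdateWeight(std::vector<ToyTensor> *weights, const ToyTensor &modify) {
    for (auto &tensor : *weights) {
      if (tensor.name != modify.name) {
        continue;
      }
      if (tensor.size != modify.size) {
        model_buff_changed_ = true;          // buffer no longer matches weights
      }
      if (!Reshape(&tensor, modify)) {       // stand-in for ReshapeWeightTensor
        model_buff_changed_ = false;         // update failed, buffer untouched
        return false;
      }
      return true;
    }
    return false;                            // no tensor with that name
  }

  bool buffer_changed() const { return model_buff_changed_; }

 private:
  static bool Reshape(ToyTensor *dst, const ToyTensor &src) {
    dst->size = src.size;
    return true;
  }
  bool model_buff_changed_ = false;          // mirrors TrainSession::model_buff_changed_
};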
@@ -1183,7 +1188,7 @@ int TrainSession::ExportInner(DestType destination, ModelType model_type, Quanti
     TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed.");
     status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
   } else {
-    if ((quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
+    if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) &&
         std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) {
           return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE;
         })) {
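On the export side, the changed condition means the path that reuses the original model buffer is taken only when UpdateWeights never resized a weight; otherwise the model is re-exported from the live kernels and tensors. A small sketch of that decision follows; the enums and helper are illustrative stand-ins, and only the combination of model_buff_changed_, quant_type and model_type checks reflects the patch.

#include <iostream>

enum ToyQuantType { QT_NONE, QT_WEIGHT };      // toy stand-ins, not the real enums
enum ToyModelType { MT_TRAIN, MT_INFERENCE };

bool CanReuseModelBuffer(bool model_buff_changed, ToyQuantType quant_type,
                         ToyModelType model_type, bool all_nodes_unquantized) {
  // The added !model_buff_changed check makes a model whose weight tensors were
  // resized by UpdateWeights fall through to the full re-export instead of
  // being written out from the now-stale original buffer.
  return !model_buff_changed && quant_type == QT_NONE && model_type == MT_TRAIN &&
         all_nodes_unquantized;
}

int main() {
  std::cout << CanReuseModelBuffer(false, QT_NONE, MT_TRAIN, true) << "\n";  // 1: buffer reused
  std::cout << CanReuseModelBuffer(true, QT_NONE, MT_TRAIN, true) << "\n";   // 0: re-export needed
  return 0;
}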
@@ -177,6 +177,7 @@ class TrainSession : virtual public lite::LiteSession {
   void *workspace_ = nullptr;
   SchedCallBack sched_mix_precision_callback_;
   bool train_mode_ = false;
+  bool model_buff_changed_ = false;
   void *tensors_data_ = nullptr;
   size_t tensors_data_size_ = 0;
   std::shared_ptr<Allocator> allocator_;