diff --git a/include/api/model.h b/include/api/model.h
index cecda6499f7..8d34871527b 100644
--- a/include/api/model.h
+++ b/include/api/model.h
@@ -64,6 +64,14 @@ class MS_API Model {
   /// \return Status.
   Status Resize(const std::vector &inputs, const std::vector> &dims);
 
+  /// \brief Change the size and/or content of weight tensors.
+  ///
+  /// \param[in] new_weights A vector of tensors with the new shapes and data to use in the model.
+  ///            If a tensor's data pointer is null, the data of the original tensor is copied into the new one.
+  ///
+  /// \return Status.
+  Status UpdateWeights(const std::vector &new_weights);
+
   /// \brief Inference model.
   ///
   /// \param[in] inputs A vector where model inputs are arranged in sequence.
diff --git a/mindspore/lite/include/lite_session.h b/mindspore/lite/include/lite_session.h
index a33dbf88f77..edb35444d72 100644
--- a/mindspore/lite/include/lite_session.h
+++ b/mindspore/lite/include/lite_session.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_LITE_INCLUDE_LITE_SESSION_H
-#define MINDSPORE_LITE_INCLUDE_LITE_SESSION_H
+#ifndef MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_
+#define MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_
 
 #ifndef NOT_USE_STL
 #include
@@ -190,6 +190,14 @@ class MS_API LiteSession {
     return mindspore::lite::RET_ERROR;
   }
 
+  /// \brief Change the size and/or content of weight tensors.
+  ///
+  /// \param[in] new_weights A vector of tensors with the new shapes and data to use in the model.
+  ///            If a tensor's data pointer is null, the data of the original tensor is copied into the new one.
+  /// \note This default implementation performs no update and returns RET_ERROR.
+  /// \return STATUS as an error code of operation, STATUS is defined in errorcode.h.
+  virtual int UpdateWeights(std::vector new_weights) { return mindspore::lite::RET_ERROR; }
+
   /// \brief Get model featuremap MindSpore Lite MSTensors of Training model prediction
   ///
   /// \return a vector of output tensors (MindSpore Lite MSTensor).
@@ -233,4 +241,4 @@ class MS_API LiteSession {
 };
 }  // namespace session
 }  // namespace mindspore
-#endif  // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H
+#endif  // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_
diff --git a/mindspore/lite/src/cxx_api/model/model.cc b/mindspore/lite/src/cxx_api/model/model.cc
index a326bad8cbc..1cd1f2c3972 100644
--- a/mindspore/lite/src/cxx_api/model/model.cc
+++ b/mindspore/lite/src/cxx_api/model/model.cc
@@ -102,6 +102,14 @@ Status Model::Resize(const std::vector &inputs, const std::vectorResize(inputs, dims);
 }
 
+Status Model::UpdateWeights(const std::vector &new_weights) {
+  if (impl_ == nullptr) {
+    MS_LOG(ERROR) << "Model implement is null.";
+    return kLiteNullptr;
+  }
+  return impl_->UpdateWeights(new_weights);
+}
+
 Status Model::Predict(const std::vector &inputs, std::vector *outputs, const MSKernelCallBack &before,
                       const MSKernelCallBack &after) {
   if (impl_ == nullptr) {
diff --git a/mindspore/lite/src/cxx_api/model/model_impl.cc b/mindspore/lite/src/cxx_api/model/model_impl.cc
index 1359d4a3544..5e86996ce8a 100644
--- a/mindspore/lite/src/cxx_api/model/model_impl.cc
+++ b/mindspore/lite/src/cxx_api/model/model_impl.cc
@@ -559,6 +559,32 @@ Status ModelImpl::Resize(const std::vector &inputs, const std::vector<
   return static_cast(ret);
 }
 
+// Unwraps the given tensors into their underlying lite tensors and forwards them to session_->UpdateWeights().
+// Fails on a null session, an empty list, or a tensor that has no implementation / underlying lite tensor.
+Status ModelImpl::UpdateWeights(const std::vector &new_weights) {
+  if (session_ == nullptr) {
+    MS_LOG(ERROR) << "Session is null.";
+    return kLiteNullptr;
+  }
+  if (new_weights.empty()) {
+    MS_LOG(ERROR) << "New weights are empty.";
+    return kLiteInputParamInvalid;
+  }
+  std::vector inner_weights;
+  inner_weights.resize(new_weights.size());
+  for (size_t i = 0; i < new_weights.size(); i++) {
+    // Bind by reference: the wrapper is only inspected here, so copying it per iteration is unnecessary.
+    const auto &weight = new_weights[i];
+    if (weight.impl_ == nullptr || weight.impl_->lite_tensor() == nullptr) {
+      MS_LOG(ERROR) << "Input tensor " << weight.Name() << " is null.";
+      return kLiteInputTensorError;
+    }
+    inner_weights[i] = weight.impl_->lite_tensor();
+  }
+  auto ret = session_->UpdateWeights(inner_weights);
+  return static_cast(ret);
+}
+
 session::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context) {
   auto session = new (std::nothrow) lite::LiteSession();
   if (session == nullptr) {
diff --git a/mindspore/lite/src/cxx_api/model/model_impl.h b/mindspore/lite/src/cxx_api/model/model_impl.h
index 99935d60ea6..90a6533cb90 100644
--- a/mindspore/lite/src/cxx_api/model/model_impl.h
+++ b/mindspore/lite/src/cxx_api/model/model_impl.h
@@ -63,6 +63,7 @@ class ModelImpl {
                const std::shared_ptr &model_context);
   Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr &model_context);
   Status Resize(const std::vector &inputs, const std::vector> &dims);
+  Status UpdateWeights(const std::vector &new_weights);
 
   Status Predict(const std::vector &inputs, std::vector *outputs, const MSKernelCallBack &before,
                  const MSKernelCallBack &after);
diff --git a/mindspore/lite/src/cxx_api/types.cc b/mindspore/lite/src/cxx_api/types.cc
index 60830ead88f..4ea737940ca 100644
--- a/mindspore/lite/src/cxx_api/types.cc
+++ b/mindspore/lite/src/cxx_api/types.cc
@@ -91,7 +91,7 @@ MSTensor *MSTensor::CreateTensor(const std::vector &name, enum DataType ty
     return nullptr;
   }
   if (data_len > 0 && data == nullptr) {
-    MS_LOG(ERROR) << "Mull data ptr of tensor.";
+    MS_LOG(ERROR) << "Null data ptr of tensor.";
     return nullptr;
   }
   auto impl = Impl::CreateTensorImpl(CharToString(name), type, shape, nullptr, data_len);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc
index 085c381ee8e..eb03a177c0c 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc
@@ -28,7 +28,7 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits;
 
namespace mindspore::kernel { -int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return RET_OK; } +int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return Prepare(); } int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses, float *output) const { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.cc index 5141274fbcb..a603664e513 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/strided_slice_grad.cc @@ -50,8 +50,6 @@ int StridedSliceGradCPUKernel::Prepare() { MS_LOG(ERROR) << "Not supported data type: " << input->data_type(); return RET_ERROR; } - FillEmptyDims(); - FillOutputDim(); return ReSize(); } @@ -113,7 +111,11 @@ void StridedSliceGradCPUKernel::FillOutputDim() { } } -int StridedSliceGradCPUKernel::ReSize() { return RET_OK; } +int StridedSliceGradCPUKernel::ReSize() { + FillEmptyDims(); + FillOutputDim(); + return RET_OK; +} int StridedSliceGradImpl(void *cdata, int task_id, float lhs_scale, float rhs_scale) { CHECK_NULL_RETURN(cdata); diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc index 0a5bb42991c..ea524b4ff2d 100644 --- a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -176,6 +176,89 @@ int TrainSession::InitCallBack() { return RET_OK; } +static int ReshapeWeightTensor(Tensor *orig_tensor, tensor::MSTensor *new_tensor) { + if (orig_tensor->data_type() != new_tensor->data_type()) { + MS_LOG(ERROR) << "Cannot reshape tensor of different type: " << new_tensor->tensor_name(); + return RET_PARAM_INVALID; + } + + if (orig_tensor->category() != lite::Category::CONST_TENSOR) { + MS_LOG(ERROR) << "Cannot reshape non const tensor: " << new_tensor->tensor_name(); + return RET_ERROR; + } + + auto orig_size = 
orig_tensor->Size(); + uint8_t *new_data = reinterpret_cast(new_tensor->data()); + if (new_data == nullptr) { + // Copy original data into new_tensor + new_data = reinterpret_cast(new_tensor->MutableData()); + if (new_data == nullptr) { + MS_LOG(ERROR) << "Allocation of Data Failed" << new_tensor->tensor_name(); + return RET_ERROR; + } + if (orig_size == 0) { + MS_LOG(ERROR) << "Operation failed: Both new tensors and original one have no data"; + return RET_ERROR; + } + uint8_t *orig_data = reinterpret_cast(orig_tensor->data()); + for (unsigned int loc = 0; loc < new_tensor->Size(); loc++) { + new_data[loc] = orig_data[loc % orig_size]; + } + } + + orig_tensor->FreeData(); + orig_tensor->set_data(nullptr); + orig_tensor->set_shape(new_tensor->shape()); + + uint8_t *dst_data = reinterpret_cast(orig_tensor->MutableData()); + if (dst_data == nullptr) { + MS_LOG(ERROR) << "Allocation of Data Failed"; + return RET_ERROR; + } + std::copy(new_data, new_data + orig_tensor->Size(), dst_data); + return RET_OK; +} + +int TrainSession::UpdateWeights(std::vector modify_tensors) { + unsigned int num_of_found_tensors = 0; + for (auto tensor : tensors_) { + for (auto modify : modify_tensors) { + if (modify == nullptr) { + MS_LOG(ERROR) << "Tensor is nullptr"; + return RET_PARAM_INVALID; + } + if (modify->tensor_name() == tensor->tensor_name()) { + auto ret = ReshapeWeightTensor(tensor, modify); + num_of_found_tensors++; + if (ret != RET_OK) { + return ret; + } + break; + } + } + } + if (num_of_found_tensors != modify_tensors.size()) { + MS_LOG(ERROR) << "Did not find all the given tensors in the model"; + return RET_ERROR; + } + auto ret = ReSizeKernels(kernels_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Resize kernels fail!"; + return ret; + } + + bool is_eval = IsEval(); + ret = Train(); // This will trigger proper Allocation of static data; + if (ret != RET_OK) { + MS_LOG(ERROR) << "General failure occurred during Update of Weights"; + return ret; + } + if (is_eval) { + ret = 
Eval(); + } + return ret; +} + int TrainSession::AllocTensors(const std::vector &kernels) { if (!IS_STATIC_ALLOCATOR(allocator_)) return RET_OK; OptAllocator allocator; @@ -199,8 +282,12 @@ int TrainSession::AllocTensors(const std::vector &kernels) } } // Set Tensor data + auto size = allocator.total_size(); + if (size > tensors_data_size_) { + free(tensors_data_); + tensors_data_ = nullptr; + } if (tensors_data_ == nullptr) { - auto size = allocator.total_size(); auto buf = malloc(size); if (buf == nullptr) { MS_LOG(ERROR) << "cannot allocate buffer size" << size; @@ -209,6 +296,7 @@ int TrainSession::AllocTensors(const std::vector &kernels) StaticAllocator *alloc = reinterpret_cast(allocator_.get()); alloc->SetContex(buf, size); tensors_data_ = buf; + tensors_data_size_ = size; } for (auto kernel : train_kernels_) { for (auto tensor : kernel->out_tensors()) { diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h index 7ac679de060..d125146bf64 100644 --- a/mindspore/lite/src/train/train_session.h +++ b/mindspore/lite/src/train/train_session.h @@ -85,6 +85,7 @@ class TrainSession : virtual public lite::LiteSession { return lite::LiteSession::GetOutputByTensorName(tensor_name); } int Resize(const std::vector &inputs, const std::vector> &dims) override; + int UpdateWeights(std::vector new_weights) override; std::vector GetPredictions() const override { std::vector outputs; @@ -166,6 +167,7 @@ class TrainSession : virtual public lite::LiteSession { SchedCallBack sched_mix_precision_callback_; bool train_mode_ = false; void *tensors_data_ = nullptr; + unsigned int tensors_data_size_ = 0; std::shared_ptr allocator_; }; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/cxx_api/model_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/cxx_api/model_test.cc index e10d02f3123..2ea6bec5a1e 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/cxx_api/model_test.cc +++ 
b/mindspore/lite/test/ut/src/runtime/kernel/arm/cxx_api/model_test.cc @@ -229,4 +229,29 @@ TEST_F(TestCxxApiLiteModel, test_fp16_SUCCESS) { train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true; ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess); } + +#define NUM_OF_CLASSES 10 +#define FEATURE_SIZE 10 +TEST_F(TestCxxApiLiteModel, set_weights_FAILURE) { + Model model; + Graph graph; + auto context = std::make_shared(); + auto cpu_context = std::make_shared(); + cpu_context->SetEnableFP16(true); + context->MutableDeviceInfo().push_back(cpu_context); + auto train_cfg = std::make_shared(); + train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true; + + ASSERT_TRUE(Serialization::Load("./nets/mix_lenet_tod.ms", ModelType::kMindIR, &graph) == kSuccess); + ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess); + std::vector changes; + ASSERT_TRUE(model.UpdateWeights(changes) != kSuccess); + changes.push_back( + *MSTensor::CreateTensor("fc4.weight", mindspore::DataType::kNumberTypeFloat32, {NUM_OF_CLASSES}, nullptr, 0)); + ASSERT_TRUE(model.UpdateWeights(changes) != kSuccess); + changes.clear(); + changes.push_back( + *MSTensor::CreateTensor("fc3.bias", mindspore::DataType::kNumberTypeFloat32, {NUM_OF_CLASSES}, nullptr, 0)); + ASSERT_TRUE(model.UpdateWeights(changes) == kSuccess); +} } // namespace mindspore