From 059c8fb2832a358ee3f6e69342fe0a376c9e9e89 Mon Sep 17 00:00:00 2001 From: Emir Haleva Date: Tue, 21 Dec 2021 18:39:40 +0200 Subject: [PATCH] Extend cxx_api to other ToD methods --- include/api/model.h | 32 ++++++++++++++- .../examples/unified_api/src/net_runner.cc | 4 +- mindspore/lite/src/cxx_api/model/model.cc | 24 +++++++++++ .../lite/src/cxx_api/model/model_impl.cc | 27 +++++++++++++ mindspore/lite/src/cxx_api/model/model_impl.h | 4 ++ mindspore/lite/src/cxx_api/train/model.cc | 32 +++++++++++++++ .../lite/src/cxx_api/train/model_impl.cc | 40 +++++++++++++++++++ ...parse_softmax_cross_entropy_with_logits.cc | 5 ++- mindspore/lite/src/train/transfer_session.cc | 10 ++--- mindspore/lite/src/train/transfer_session.h | 4 ++ .../runtime/kernel/arm/cxx_api/model_test.cc | 18 +++++++++ .../quantizer/mixed_bit_weight_quantizer.cc | 8 ++-- .../converter/quantizer/parameter_tunner.cc | 8 ++-- 13 files changed, 199 insertions(+), 17 deletions(-) diff --git a/include/api/model.h b/include/api/model.h index 36e45711391..7d77f74e26a 100644 --- a/include/api/model.h +++ b/include/api/model.h @@ -45,7 +45,7 @@ class MS_API Model { Model(const Model &) = delete; void operator=(const Model &) = delete; - /// \brief Builds a model so that it can run on a device. + /// \brief Builds a model /// /// \param[in] graph GraphCell is a derivative of Cell. Cell is not available currently. GraphCell can be constructed /// from Graph, for example, model.Build(GraphCell(graph), context). 
@@ -56,6 +56,17 @@ class MS_API Model { Status Build(GraphCell graph, const std::shared_ptr &model_context = nullptr, const std::shared_ptr &train_cfg = nullptr); + /// \brief Builds a Transfer Learning model where the backbone weights are fixed and the head weights are trainable + /// + /// \param[in] backbone The static, non-learnable part of the graph + /// \param[in] head The trainable part of the graph + /// \param[in] context A context used to store options during execution + /// \param[in] train_cfg A config used by training + /// + /// \return Status + Status BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr &context, + const std::shared_ptr &train_cfg = nullptr); + /// \brief Resizes the shapes of inputs. /// /// \param[in] inputs A vector that includes all input tensors in order. @@ -173,6 +184,25 @@ class MS_API Model { /// \return Status of operation Status SetOptimizerParams(const std::vector &params); + /// \brief Setup training with virtual batches + /// + /// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable + /// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration + /// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration + /// \return Status of operation + Status SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f); + + /// \brief Sets the Learning Rate of the training + /// + /// \param[in] learning_rate to set + /// \return Status of operation + Status SetLearningRate(float learning_rate); + + /// \brief Gets the Learning Rate of the optimizer + /// + /// \return learning rate. 
0.0 if no optimizer was found + float GetLearningRate(); + Status InitMetrics(std::vector metrics); std::vector GetMetrics(); diff --git a/mindspore/lite/examples/unified_api/src/net_runner.cc b/mindspore/lite/examples/unified_api/src/net_runner.cc index 3953942bca0..f402b526913 100644 --- a/mindspore/lite/examples/unified_api/src/net_runner.cc +++ b/mindspore/lite/examples/unified_api/src/net_runner.cc @@ -207,6 +207,8 @@ int NetRunner::TrainLoop() { Measurement measure(epochs_); if (virtual_batch_ > 0) { + auto status = model_->SetupVirtualBatch(virtual_batch_); + MS_ASSERT(status == mindspore::kSuccess); model_->Train(epochs_, train_ds_, {&rescale, &lm, &cs, &measure}); } else { struct mindspore::StepLRLambda step_lr_lambda(1, kGammaFactor); @@ -237,7 +239,7 @@ int NetRunner::Main() { void NetRunner::Usage() { std::cout << "Usage: net_runner -f <.ms model file> -d [-e ] " - << "[-v (verbose mode)] [-s ]" << std::endl; + << "[-b ] [-v (verbose mode)] [-s ]" << std::endl; } bool NetRunner::ReadArgs(int argc, char *argv[]) { diff --git a/mindspore/lite/src/cxx_api/model/model.cc b/mindspore/lite/src/cxx_api/model/model.cc index 5d6de2d9049..f7672f13cd2 100644 --- a/mindspore/lite/src/cxx_api/model/model.cc +++ b/mindspore/lite/src/cxx_api/model/model.cc @@ -365,4 +365,28 @@ std::vector Model::GetMetrics() { return impl_->GetMetrics(); } +Status Model::SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) { + if (impl_ == nullptr) { + MS_LOG(ERROR) << "Model implement is null."; + return kLiteUninitializedObj; + } + return impl_->SetupVirtualBatch(virtual_batch_multiplier, lr, momentum); +} + +Status Model::SetLearningRate(float learning_rate) { + if (impl_ == nullptr) { + MS_LOG(ERROR) << "Model implement is null."; + return kLiteUninitializedObj; + } + return impl_->SetLearningRate(learning_rate); +} + +float Model::GetLearningRate() { + if (impl_ == nullptr) { + MS_LOG(WARNING) << "Model implement is null."; + return 0.0; + } + return 
impl_->GetLearningRate(); +} + } // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/model/model_impl.cc b/mindspore/lite/src/cxx_api/model/model_impl.cc index d4601c5ca17..a24bc61a075 100644 --- a/mindspore/lite/src/cxx_api/model/model_impl.cc +++ b/mindspore/lite/src/cxx_api/model/model_impl.cc @@ -650,6 +650,32 @@ Status ModelImpl::UpdateWeights(const std::vector &new_weights) { return static_cast(ret); } +Status ModelImpl::SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) { + if (session_ == nullptr) { + MS_LOG(ERROR) << "Session is null."; + return kLiteNullptr; + } + auto ret = session_->SetupVirtualBatch(virtual_batch_multiplier, lr, momentum); + return static_cast(ret); +} + +Status ModelImpl::SetLearningRate(float learning_rate) { + if (session_ == nullptr) { + MS_LOG(ERROR) << "Session is null."; + return kLiteNullptr; + } + auto ret = session_->SetLearningRate(learning_rate); + return static_cast(ret); +} + +float ModelImpl::GetLearningRate() { + if (session_ == nullptr) { + MS_LOG(WARNING) << "Session is null."; + return 0.0; + } + return session_->GetLearningRate(); +} + lite::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context) { auto session = new (std::nothrow) lite::LiteSession(); if (session == nullptr) { @@ -669,4 +695,5 @@ lite::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context) { } return session; } + } // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/model/model_impl.h b/mindspore/lite/src/cxx_api/model/model_impl.h index 2cf48991016..0499fc2a733 100644 --- a/mindspore/lite/src/cxx_api/model/model_impl.h +++ b/mindspore/lite/src/cxx_api/model/model_impl.h @@ -92,6 +92,10 @@ class ModelImpl { static bool CheckModelSupport(const std::string &device_type, ModelType model_type); bool IsTrainModel(); + Status SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum); + Status SetLearningRate(float learning_rate); + float GetLearningRate(); + 
Status BuildTransferLearning(const std::shared_ptr &backbone, const std::shared_ptr &head); Status InitMetrics(const std::vector metrics) { metrics_ = metrics; diff --git a/mindspore/lite/src/cxx_api/train/model.cc b/mindspore/lite/src/cxx_api/train/model.cc index 3eb176b1fcf..9e91a40c4ac 100644 --- a/mindspore/lite/src/cxx_api/train/model.cc +++ b/mindspore/lite/src/cxx_api/train/model.cc @@ -106,4 +106,36 @@ Status Model::Evaluate(std::shared_ptr ds, std::vector &context, + const std::shared_ptr &train_cfg) { + std::stringstream err_msg; + if (impl_ == nullptr) { + impl_ = std::shared_ptr(new (std::nothrow) ModelImpl()); + if (impl_ == nullptr) { + MS_LOG(ERROR) << "Model implement is null."; + return kLiteFileError; + } + } + + if (backbone.GetGraph() == nullptr || head.GetGraph() == nullptr) { + err_msg << "Invalid null graph."; + MS_LOG(ERROR) << err_msg.str(); + return Status(kLiteNullptr, err_msg.str()); + } + if (context == nullptr) { + err_msg << "Invalid null context."; + MS_LOG(ERROR) << err_msg.str(); + return Status(kLiteNullptr, err_msg.str()); + } + impl_->SetContext(context); + impl_->SetGraph(head.GetGraph()); + impl_->SetConfig(train_cfg); + + Status ret = impl_->BuildTransferLearning(backbone.GetGraph(), head.GetGraph()); + if (ret != kSuccess) { + return ret; + } + return kSuccess; +} + } // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/train/model_impl.cc b/mindspore/lite/src/cxx_api/train/model_impl.cc index 0b0bc1fbec5..505bbcb315d 100644 --- a/mindspore/lite/src/cxx_api/train/model_impl.cc +++ b/mindspore/lite/src/cxx_api/train/model_impl.cc @@ -37,8 +37,48 @@ #include "src/cxx_api/callback/callback_impl.h" #include "src/common/log_adapter.h" #include "src/train/train_session.h" +#include "src/train/transfer_session.h" namespace mindspore { +Status ModelImpl::BuildTransferLearning(const std::shared_ptr &backbone, const std::shared_ptr &head) { + const auto b_graph_data = backbone->graph_data_; + const auto h_graph_data = 
head->graph_data_; + if (b_graph_data == nullptr || h_graph_data == nullptr) { + MS_LOG(ERROR) << "graph data cannot be nullptr"; + return kLiteNullptr; + } + bool is_train_session = h_graph_data->IsTrainModel(); + if (is_train_session) { + const auto b_model = reinterpret_cast(b_graph_data->lite_model().get()); + const auto h_model = reinterpret_cast(h_graph_data->lite_model().get()); + if (h_model == nullptr || h_model->buf == nullptr || b_model == nullptr || b_model->buf == nullptr) { + MS_LOG(ERROR) << "Lite model has been freed."; + return kLiteNullptr; + } + + lite::TrainCfg train_cfg; + if (cfg_ != nullptr) { + auto status = A2L_ConvertConfig(cfg_.get(), &train_cfg); + if (status != kSuccess) { + MS_LOG(ERROR) << "Failed to convert Config to Lite Config"; + return status; + } + } + + auto session = std::shared_ptr( + CreateTransferSessionInt(b_model->buf, b_model->buf_size_, h_model->buf, h_model->buf_size_, + ContextUtils::Convert(context_.get()), true, &train_cfg)); + if (session == nullptr) { + MS_LOG(ERROR) << "create session failed"; + return kLiteMemoryFailed; + } + session_.swap(session); + return kSuccess; + } + MS_LOG(DEBUG) << "Session is not a train session."; + return kLiteError; +} + Status ModelImpl::PrepareMetrics(Model *model, std::vector *out_ms, std::vector *adapter_ms) { if (out_ms == nullptr || adapter_ms == nullptr) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc index 698eb3b44a2..9db82c3e727 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc @@ -94,7 +94,7 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Execute(int task_id) { int length = sm_params_.input_shape_[sm_params_.axis_]; int stride = UP_DIV(outter_size_, threads_); int 
count = MSMIN(stride, outter_size_ - stride * task_id); - if (count <= 0) return RET_ERROR; + if (count <= 0) return RET_OK; switch (stage_) { case 0: SoftMaxP1(ins, losses, sum_data, task_id * stride, count, length, inner_size_); @@ -145,7 +145,8 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() { } inner_size_ = inner_size; outter_size_ = outter_size; - const std::vector threads = {op_parameter_->thread_num_, op_parameter_->thread_num_, 1}; + int max_num_of_threads = (outter_size_ < op_parameter_->thread_num_) ? outter_size_ : op_parameter_->thread_num_; + const std::vector threads = {max_num_of_threads, max_num_of_threads, 1}; for (int stage = 0; stage < static_cast(threads.size()); stage++) { stage_ = stage; threads_ = threads.at(stage); diff --git a/mindspore/lite/src/train/transfer_session.cc b/mindspore/lite/src/train/transfer_session.cc index 28532cfeb43..fc98473584c 100644 --- a/mindspore/lite/src/train/transfer_session.cc +++ b/mindspore/lite/src/train/transfer_session.cc @@ -229,12 +229,10 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q if (orig_train_state) Train(); return status; } -} // namespace lite -static session::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone, - const char *model_buf_head, size_t size_head, - const lite::Context *context, bool train_mode, - const lite::TrainCfg *cfg) { +lite::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone, + const char *model_buf_head, size_t size_head, const lite::Context *context, + bool train_mode, const lite::TrainCfg *cfg) { auto ValidModelSize = [](size_t size) -> bool { constexpr size_t MaxModelSize = 1024 * 1024 * 1024ULL; // 1G B return size < MaxModelSize && size > 0; @@ -299,6 +297,8 @@ static session::LiteSession *CreateTransferSessionInt(const char *model_buf_back return session; } +} // namespace lite + session::LiteSession 
*session::TrainSession::CreateTransferSession(const std::string &filename_backbone, const std::string &filename_head, const lite::Context *ctxt, bool train_mode, diff --git a/mindspore/lite/src/train/transfer_session.h b/mindspore/lite/src/train/transfer_session.h index 04fca0e1582..ab30d539944 100644 --- a/mindspore/lite/src/train/transfer_session.h +++ b/mindspore/lite/src/train/transfer_session.h @@ -77,6 +77,10 @@ class TransferSession : public lite::TrainSession { bool nchw2nhwc_ = false; size_t size_backbone_; }; + +lite::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone, + const char *model_buf_head, size_t size_head, const lite::Context *context, + bool train_mode, const lite::TrainCfg *cfg); } // namespace lite } // namespace mindspore #endif // MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/cxx_api/model_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/cxx_api/model_test.cc index 81a9bd132c1..e18f8e117eb 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/cxx_api/model_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/cxx_api/model_test.cc @@ -233,4 +233,22 @@ TEST_F(TestCxxApiLiteModel, set_weights_FAILURE) { *MSTensor::CreateTensor("fc3.bias", mindspore::DataType::kNumberTypeFloat32, {NUM_OF_CLASSES}, nullptr, 0)); ASSERT_TRUE(model.UpdateWeights(changes) == kSuccess); } + +TEST_F(TestCxxApiLiteModel, set_get_lr_SUCCESS) { + Model model; + Graph graph; + float learn_rate = 0.2; + auto context = std::make_shared(); + auto cpu_context = std::make_shared(); + cpu_context->SetEnableFP16(true); + context->MutableDeviceInfo().push_back(cpu_context); + auto train_cfg = std::make_shared(); + + ASSERT_TRUE(Serialization::Load("./nets/mix_lenet_tod.ms", ModelType::kMindIR, &graph) == kSuccess); + ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess); + + ASSERT_TRUE(model.SetLearningRate(learn_rate) == kSuccess); + 
ASSERT_TRUE(model.GetLearningRate() == learn_rate); +} + } // namespace mindspore diff --git a/mindspore/lite/tools/converter/quantizer/mixed_bit_weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/mixed_bit_weight_quantizer.cc index 11c2e55e915..2ac5adde549 100644 --- a/mindspore/lite/tools/converter/quantizer/mixed_bit_weight_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/mixed_bit_weight_quantizer.cc @@ -65,8 +65,8 @@ float MixedBitWeightQuantizer::CalculateMeanError(std::vector norms2, std error_count += 1; mse_error += sqrtf(dnorms2[i] / norms2[i]); } - auto meam_error = mse_error / (error_count + soft); - return meam_error; + auto mean_error = mse_error / (error_count + soft); + return mean_error; } // the `preferred` dim should point to the output channels dimension. @@ -109,8 +109,8 @@ float MixedBitWeightQuantizer::MeasureQuantizationError(float *weights, const in float d = weights[i] - dequant; dnorms2[bucket] += d * d; } - auto meam_error = CalculateMeanError(norms2, dnorms2); - return meam_error; + auto mean_error = CalculateMeanError(norms2, dnorms2); + return mean_error; } MinMax MixedBitWeightQuantizer::GetMinMax(const float *arr, int arrc) { diff --git a/mindspore/lite/tools/converter/quantizer/parameter_tunner.cc b/mindspore/lite/tools/converter/quantizer/parameter_tunner.cc index 8223d1e2b63..b13f094559d 100644 --- a/mindspore/lite/tools/converter/quantizer/parameter_tunner.cc +++ b/mindspore/lite/tools/converter/quantizer/parameter_tunner.cc @@ -249,14 +249,14 @@ int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, conve delete origin_model; return RET_OK; } - int babysitting_rounds = 25; - step = (min_max.max - min_max.min) / babysitting_rounds; + int baby_step_rounds = 25; + step = (min_max.max - min_max.min) / baby_step_rounds; - param.rounds = babysitting_rounds; + param.rounds = baby_step_rounds; param.start_scale = start_scale; param.step = step; param.thread_num = 
flags->commonQuantParam.thread_num; - std::cout << "==========Search with babysitting step==============\n"; + std::cout << "==========Search with baby step==============\n"; ret = WeightQuantModelInference(func_graph, flags, origin_session, origin_model_size, param, init_scale, &candidate_scales, true); if (ret != RET_OK) {