forked from mindspore-Ecosystem/mindspore

Added Federated Learning API

parent 5a851daf2f
commit a14eac9fb2
@@ -45,6 +45,7 @@ class TrainCfg {
   OptimizationLevel optimization_level_ = kO0;
   std::string loss_name_;             /**< Set part of the name that identifies a loss kernel */
   MixPrecisionCfg mix_precision_cfg_; /**< Mixed precision configuration */
+  bool accumulate_gradients_ = false;
 };

 }  // namespace mindspore
@@ -92,6 +92,28 @@ class MS_API Model {
   /// \return The input tensor with the given name; if the name is not found, an invalid tensor is returned.
   inline MSTensor GetInputByTensorName(const std::string &tensor_name);

+  /// \brief Obtains all gradient tensors of the model.
+  ///
+  /// \return The vector that includes all gradient tensors.
+  std::vector<MSTensor> GetGradients() const;
+
+  /// \brief Updates the gradient tensors of the model.
+  ///
+  /// \param[in] gradients A vector of new gradient tensors.
+  /// \return Status of the operation.
+  Status ApplyGradients(const std::vector<MSTensor> &gradients);
+
+  /// \brief Obtains the optimizer parameter tensors of the model.
+  ///
+  /// \return The vector that includes all optimizer parameter tensors.
+  std::vector<MSTensor> GetOptimizerParams() const;
+
+  /// \brief Updates the optimizer parameters.
+  ///
+  /// \param[in] params A vector of new optimizer parameter values.
+  /// \return Status of the operation.
+  Status SetOptimizerParams(const std::vector<MSTensor> &params);
+
   Status InitMetrics(std::vector<Metrics *> metrics);
   std::vector<Metrics *> GetMetrics();
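Together these four methods are the federated-learning hooks at the C++ API level: pull accumulated gradients off a client, aggregate across clients, and push the result back. Below is a minimal aggregation sketch built only on this new surface; it assumes each Model was built with accumulate_gradients_ enabled and has already run its local steps, that all clients expose gradients in the same order and shapes, and AverageAndApply is a hypothetical helper name, not part of this commit.

// Hedged sketch: federated averaging over the new Model API (not in this diff).
#include <algorithm>
#include <vector>
#include "include/api/model.h"
#include "include/api/status.h"

mindspore::Status AverageAndApply(const std::vector<mindspore::Model *> &clients,
                                  mindspore::Model *server) {
  auto avg = server->GetGradients();  // server-side tensors reused as accumulators
  for (auto &g : avg) {
    auto *dst = static_cast<float *>(g.MutableData());
    std::fill(dst, dst + g.ElementNum(), 0.0f);
  }
  for (auto *client : clients) {
    auto grads = client->GetGradients();  // assumed to match avg in order and shape
    for (size_t i = 0; i < grads.size(); ++i) {
      auto *src = static_cast<float *>(grads[i].MutableData());
      auto *dst = static_cast<float *>(avg[i].MutableData());
      for (int64_t j = 0; j < grads[i].ElementNum(); ++j) {
        dst[j] += src[j] / static_cast<float>(clients.size());
      }
    }
  }
  return server->ApplyGradients(avg);  // copies values in and runs one optimizer step
}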
@@ -202,6 +202,34 @@ class MS_API LiteSession {
   /// \param[in] features new feature map tensors
   /// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
   virtual int UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features) { return mindspore::lite::RET_ERROR; }
+
+  /// \brief Get the model gradients.
+  ///
+  /// \return a vector of gradient tensors (MindSpore Lite MSTensor).
+  virtual std::vector<tensor::MSTensor *> GetGradients() const {
+    std::vector<tensor::MSTensor *> gradients;
+    return gradients;
+  }
+
+  /// \brief Update the model gradients.
+  ///
+  /// \param[in] gradients new gradient tensors
+  /// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
+  virtual int ApplyGradients(const std::vector<tensor::MSTensor *> &gradients) { return mindspore::lite::RET_ERROR; }
+
+  /// \brief Get the model optimizer parameters.
+  ///
+  /// \return a vector of optimizer parameter tensors (MindSpore Lite MSTensor).
+  virtual std::vector<tensor::MSTensor *> GetOptimizerParams() const {
+    std::vector<tensor::MSTensor *> params;
+    return params;
+  }
+
+  /// \brief Set the model optimizer parameters.
+  ///
+  /// \param[in] params new optimizer parameter values
+  /// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
+  virtual int SetOptimizerParams(const std::vector<tensor::MSTensor *> &params) { return mindspore::lite::RET_ERROR; }
 };
 }  // namespace session
 }  // namespace mindspore
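Worth noting: the LiteSession base class ships inert defaults on purpose. The two getters return empty vectors and the two setters return RET_ERROR, so plain inference sessions are unaffected, and only TrainSession (further down in this commit) overrides them with real behavior.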
@@ -55,14 +55,17 @@ class TrainCfg {
   TrainCfg(const TrainCfg &rhs) {
     this->loss_name_ = rhs.loss_name_;
     this->mix_precision_cfg_ = rhs.mix_precision_cfg_;
+    this->accumulate_gradients_ = rhs.accumulate_gradients_;
   }
   TrainCfg &operator=(const TrainCfg &rhs) {
     this->loss_name_ = rhs.loss_name_;
     this->mix_precision_cfg_ = rhs.mix_precision_cfg_;
+    this->accumulate_gradients_ = rhs.accumulate_gradients_;
     return *this;
   }
   std::string loss_name_;             /**< Set part of the name that identifies a loss kernel */
   MixPrecisionCfg mix_precision_cfg_; /**< Mixed precision configuration */
+  bool accumulate_gradients_ = false; /**< If true, gradients are accumulated and can be read via GetGradients */
 };

 }  // namespace lite
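The flag is carried through both the copy constructor and operator=; dropping either would silently lose accumulation whenever a TrainCfg is copied on its way into Model::Build. A minimal opt-in sketch (the header path is an assumption; the unit tests at the bottom of this commit do the same thing inline):

#include <memory>
#include "include/api/cfg.h"  // assumed location of mindspore::TrainCfg

std::shared_ptr<mindspore::TrainCfg> MakeAccumulatingCfg() {
  auto cfg = std::make_shared<mindspore::TrainCfg>();
  cfg->accumulate_gradients_ = true;  // gradients become readable via Model::GetGradients()
  return cfg;
}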
@@ -202,6 +202,41 @@ Status Model::SetTrainMode(bool train) {

 bool Model::GetTrainMode() const { return ((impl_ != nullptr) && (impl_->session_) && (impl_->session_->IsTrain())); }

+std::vector<MSTensor> Model::GetGradients() const {
+  std::vector<MSTensor> empty;
+  if (impl_ == nullptr) {
+    MS_LOG(ERROR) << "Model implementation is null.";
+    return empty;
+  }
+  return impl_->GetGradients();
+}
+
+Status Model::ApplyGradients(const std::vector<MSTensor> &gradients) {
+  if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
+    MS_LOG(ERROR) << "Model is null.";
+    return kLiteUninitializedObj;
+  }
+  return impl_->ApplyGradients(gradients);
+}
+
+std::vector<MSTensor> Model::GetOptimizerParams() const {
+  std::vector<MSTensor> empty;
+  if (impl_ == nullptr) {
+    MS_LOG(ERROR) << "Model implementation is null.";
+    return empty;
+  }
+  auto res = impl_->GetOptimizerParams();
+  return res;
+}
+
+Status Model::SetOptimizerParams(const std::vector<MSTensor> &params) {
+  if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
+    MS_LOG(ERROR) << "Model is null.";
+    return kLiteUninitializedObj;
+  }
+  return impl_->SetOptimizerParams(params);
+}
+
 Status Model::InitMetrics(std::vector<Metrics *> metrics) {
   if (impl_ == nullptr) {
     MS_LOG(ERROR) << "Model implementation is null.";
@@ -349,6 +349,82 @@ std::vector<MSTensor> ModelImpl::GetOutputs() {
   return res;
 }

+std::vector<MSTensor> ModelImpl::GetGradients() const {
+  std::vector<MSTensor> empty;
+  if (session_ == nullptr) {
+    MS_LOG(ERROR) << "Session is null.";
+    return empty;
+  }
+  auto gradients = session_->GetGradients();
+  if (gradients.empty()) {
+    MS_LOG(ERROR) << "No gradients available.";
+    return empty;
+  }
+  std::vector<MSTensor> res = LiteTensorsToMSTensors(gradients);
+  return res;
+}
+
+Status ModelImpl::ApplyGradients(const std::vector<MSTensor> &gradients) {
+  if (session_ == nullptr) {
+    MS_LOG(ERROR) << "Session is null.";
+    return kLiteNullptr;
+  }
+  if (gradients.empty()) {
+    MS_LOG(ERROR) << "gradients vector is empty.";
+    return kLiteInputParamInvalid;
+  }
+  std::vector<tensor::MSTensor *> inner_gradients;
+  inner_gradients.resize(gradients.size());
+  for (size_t i = 0; i < gradients.size(); i++) {
+    auto gradient = gradients[i];
+    if (gradient.impl_ == nullptr || gradient.impl_->lite_tensor() == nullptr) {
+      MS_LOG(ERROR) << "gradient tensor " << gradient.Name() << " is null.";
+      return kLiteInputTensorError;
+    }
+    inner_gradients[i] = gradient.impl_->lite_tensor();
+  }
+  auto ret = session_->ApplyGradients(inner_gradients);
+  return static_cast<StatusCode>(ret);
+}
+
+std::vector<MSTensor> ModelImpl::GetOptimizerParams() const {
+  std::vector<MSTensor> empty;
+  if (session_ == nullptr) {
+    MS_LOG(ERROR) << "Session is null.";
+    return empty;
+  }
+  auto params = session_->GetOptimizerParams();
+  if (params.empty()) {
+    MS_LOG(ERROR) << "No optimizer parameters available.";
+    return empty;
+  }
+  std::vector<MSTensor> res = LiteTensorsToMSTensors(params);
+  return res;
+}
+
+Status ModelImpl::SetOptimizerParams(const std::vector<MSTensor> &params) {
+  if (session_ == nullptr) {
+    MS_LOG(ERROR) << "Session is null.";
+    return kLiteNullptr;
+  }
+  if (params.empty()) {
+    MS_LOG(ERROR) << "params vector is empty.";
+    return kLiteInputParamInvalid;
+  }
+  std::vector<tensor::MSTensor *> inner_params;
+  inner_params.resize(params.size());
+  for (size_t i = 0; i < params.size(); i++) {
+    auto param = params[i];
+    if (param.impl_ == nullptr || param.impl_->lite_tensor() == nullptr) {
+      MS_LOG(ERROR) << "Param tensor " << param.Name() << " is null.";
+      return kLiteInputTensorError;
+    }
+    inner_params[i] = param.impl_->lite_tensor();
+  }
+  auto ret = session_->SetOptimizerParams(inner_params);
+  return static_cast<StatusCode>(ret);
+}
+
 MSTensor ModelImpl::GetInputByTensorName(const std::string &name) {
   if (session_ == nullptr) {
     MS_LOG(ERROR) << "Session is null.";
@@ -72,6 +72,10 @@ class ModelImpl {
   Status LoadConfig(const std::string &config_path);
   std::vector<MSTensor> GetInputs();
   std::vector<MSTensor> GetOutputs();
+  std::vector<MSTensor> GetGradients() const;
+  Status ApplyGradients(const std::vector<MSTensor> &gradients);
+  std::vector<MSTensor> GetOptimizerParams() const;
+  Status SetOptimizerParams(const std::vector<MSTensor> &params);
   MSTensor GetInputByTensorName(const std::string &name);
   std::vector<std::string> GetOutputTensorNames();
   MSTensor GetOutputByTensorName(const std::string &name);
@@ -35,7 +35,7 @@ Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cf
   l_train_cfg->mix_precision_cfg_.loss_scale_ = a_train_cfg->mix_precision_cfg_.loss_scale_;
   l_train_cfg->mix_precision_cfg_.keep_batchnorm_fp32_ = (a_train_cfg->optimization_level_ != kO3);
   l_train_cfg->mix_precision_cfg_.num_of_not_nan_iter_th_ = a_train_cfg->mix_precision_cfg_.num_of_not_nan_iter_th_;
+  l_train_cfg->accumulate_gradients_ = a_train_cfg->accumulate_gradients_;
   return kSuccess;
 }
 }  // namespace mindspore
@@ -1,4 +1,3 @@
-
 /**
  * Copyright 2020 Huawei Technologies Co., Ltd
  *
@@ -17,6 +16,7 @@
 #include "src/runtime/kernel/arm/fp32_grad/adam.h"
 #include <cmath>
+#include <string>
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
@@ -93,7 +93,9 @@ int AdamRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   auto adam_kernel = reinterpret_cast<AdamCPUKernel *>(cdata);
   CHECK_NULL_RETURN(adam_kernel);
   auto error_code = RET_OK;
-  if (adam_kernel->get_optimizer_mode() == OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
+  if (adam_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) {
+    error_code = adam_kernel->ExecuteVirtualBatch(task_id);
+  } else if (adam_kernel->get_optimizer_mode() == WeightUpdateMode::ACCUMULATE_GRADS) {
     error_code = adam_kernel->ExecuteVirtualBatch(task_id);
   } else {
     error_code = adam_kernel->Execute(task_id);
@@ -125,6 +127,11 @@ int AdamCPUKernel::Init() {
   return RET_OK;
 }

+std::vector<int> AdamCPUKernel::GetOptimizerParamsIdxs() const {
+  std::vector<int> indices = {6, 7, 3, 4, 8};
+  return indices;
+}
+
 int AdamCPUKernel::OptimizerStep() {
   CHECK_LESS_RETURN(in_tensors_.size(), 9);
   auto weight = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
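A reviewer's note on the bare literals: assuming the Adam op keeps the usual var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad input ordering (this diff does not state it), {6, 7, 3, 4, 8} would expose beta1, beta2, beta1_power, beta2_power, and epsilon, while the learning rate at index 5 is appended separately by OptimizerKernel::GetOptimizerParams(). Named constants would make the mapping self-documenting.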
@@ -40,6 +40,7 @@ class AdamCPUKernel : public OptimizerKernel {
   int Run() override;
   int Execute(int task_id);
   int OptimizerStep() override;
+  std::vector<int> GetOptimizerParamsIdxs() const override;

  private:
   int thread_count_;
@@ -1,4 +1,3 @@
-
 /**
  * Copyright 2020 Huawei Technologies Co., Ltd
  *
@@ -16,6 +15,7 @@
  */

 #include "src/runtime/kernel/arm/fp32_grad/apply_momentum.h"
+#include <string>
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
@@ -74,7 +74,9 @@ int ApplyMomentumRun(void *cdata, int task_id, float lhs_scale, float rhs_scale)
   CHECK_NULL_RETURN(cdata);
   auto applyMomentum_kernel = reinterpret_cast<ApplyMomentumCPUKernel *>(cdata);
   auto error_code = RET_OK;
-  if (applyMomentum_kernel->get_optimizer_mode() == OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
+  if (applyMomentum_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) {
+    error_code = applyMomentum_kernel->ExecuteVirtualBatch(task_id);
+  } else if (applyMomentum_kernel->get_optimizer_mode() == WeightUpdateMode::ACCUMULATE_GRADS) {
     error_code = applyMomentum_kernel->ExecuteVirtualBatch(task_id);
   } else {
     error_code = applyMomentum_kernel->Execute(task_id);
@@ -111,6 +113,11 @@ int ApplyMomentumCPUKernel::Init() {
   return RET_OK;
 }

+std::vector<int> ApplyMomentumCPUKernel::GetOptimizerParamsIdxs() const {
+  std::vector<int> indices = {4};
+  return indices;
+}
+
 int ApplyMomentumCPUKernel::OptimizerStep() {
   auto weight = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
   CHECK_NULL_RETURN(weight);
@@ -42,6 +42,7 @@ class ApplyMomentumCPUKernel : public OptimizerKernel {
   int Execute(int task_id);
   int Run() override;
   int OptimizerStep() override;
+  std::vector<int> GetOptimizerParamsIdxs() const override;

  private:
   int thread_count_;
@@ -1,4 +1,3 @@
-
 /**
  * Copyright 2020 Huawei Technologies Co., Ltd
  *
@@ -16,6 +15,7 @@
  */

 #include "src/runtime/kernel/arm/fp32_grad/sgd.h"
+#include <string>
 #include <algorithm>
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
@@ -122,7 +122,9 @@ int SgdRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   auto sgd_kernel = reinterpret_cast<SgdCPUKernel *>(cdata);
   CHECK_NULL_RETURN(sgd_kernel);
   auto error_code = RET_OK;
-  if (sgd_kernel->get_optimizer_mode() == OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
+  if (sgd_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) {
+    error_code = sgd_kernel->ExecuteVirtualBatch(task_id);
+  } else if (sgd_kernel->get_optimizer_mode() == WeightUpdateMode::ACCUMULATE_GRADS) {
     error_code = sgd_kernel->ExecuteVirtualBatch(task_id);
   } else {
     error_code = sgd_kernel->Execute(task_id);
@@ -138,7 +140,7 @@ int SgdRunInit(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   auto sgd_kernel = reinterpret_cast<SgdCPUKernel *>(cdata);
   CHECK_NULL_RETURN(sgd_kernel);
   auto error_code = RET_OK;
-  if (sgd_kernel->get_optimizer_mode() == OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
+  if (sgd_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) {
     error_code = sgd_kernel->ExecuteVirtualBatch(task_id);
   } else {
     error_code = sgd_kernel->ExecuteInit(task_id);
@@ -192,6 +194,11 @@ int SgdCPUKernel::Init() {
   return RET_OK;
 }

+std::vector<int> SgdCPUKernel::GetOptimizerParamsIdxs() const {
+  std::vector<int> indices = {4};
+  return indices;
+}
+
 int SgdCPUKernel::OptimizerStep() {
   auto weight = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
@@ -41,6 +41,7 @@ class SgdCPUKernel : public OptimizerKernel {
   int ExecuteInit(int task_id);
   int Execute(int task_id);
   int OptimizerStep() override;
+  std::vector<int> GetOptimizerParamsIdxs() const override;

  private:
   int thread_count_;
@@ -18,7 +18,11 @@
 #include <vector>
 #include <cmath>
 #include <cfloat>
+#include <algorithm>
+#include <string>
+#include <iostream>
 #include "src/lite_kernel.h"
+#include "include/ms_tensor.h"
 #include "include/errorcode.h"
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
@@ -31,6 +35,7 @@ static __attribute__((always_inline)) inline bool MS_ISNAN(float var) {

 namespace mindspore::kernel {

+enum class WeightUpdateMode { NORMAL, VIRTUAL_BATCH, ACCUMULATE_GRADS };
 class OptimizerKernel : public InnerKernel {
  public:
  OptimizerKernel() = default;
@@ -39,7 +44,6 @@ class OptimizerKernel : public InnerKernel {
       : InnerKernel(parameter, inputs, outputs, ctx), lr_idx_(lr_idx), grad_idx_(grad_idx) {}
   ~OptimizerKernel() = default;

-  enum class WeightUpdateMode { NORMAL, VIRTUAL_BATCH };
   WeightUpdateMode get_optimizer_mode() { return weight_update_mod_; }

   int Init() override {
@@ -55,13 +59,69 @@ class OptimizerKernel : public InnerKernel {

   float GetLearningRate() { return lr_; }

+  virtual std::vector<int> GetOptimizerParamsIdxs() const {
+    std::vector<int> indices;
+    return indices;
+  }
+
+  std::vector<tensor::MSTensor *> GetOptimizerParams() const {
+    std::vector<tensor::MSTensor *> params;
+    auto indices = GetOptimizerParamsIdxs();
+    indices.push_back(lr_idx_);
+    for (size_t ix = 0; ix < indices.size(); ix++) {
+      auto param = lite::Tensor::CopyTensor(*in_tensors_.at(indices[ix]));
+      param->set_tensor_name(in_tensors_.at(indices[ix])->tensor_name());
+      param->set_data(static_cast<void *>(in_tensors_.at(indices[ix])->data()));
+      param->set_own_data(false);
+      if (param->data() == nullptr) {
+        MS_LOG(ERROR) << "Tensor: " << param->tensor_name() << " has no data";
+        return params;
+      }
+      params.push_back(param);
+    }
+    return params;
+  }
+
+  bool SetOptimizerParams(tensor::MSTensor *param) {
+    if (param == nullptr) {
+      return false;
+    }
+    bool found = false;
+    auto indices = GetOptimizerParamsIdxs();
+    indices.push_back(lr_idx_);
+    for (size_t ix = 0; ix < indices.size(); ix++) {
+      if (param->tensor_name() == in_tensors_.at(indices[ix])->tensor_name()) {
+        auto value = static_cast<float *>(param->MutableData())[0];
+        static_cast<float *>(in_tensors_.at(indices[ix])->MutableData())[0] = value;
+        if (lr_idx_ == indices[ix]) {
+          lr_ = value;
+        }
+        found = true;
+        break;
+      }
+    }
+    return found;
+  }
+
+  lite::Tensor *GetGradients() {
+    lite::Tensor *grad_sum_tensor = nullptr;
+    if (grad_sum_ != nullptr) {
+      auto shape = in_tensors_.at(grad_idx_)->shape();
+      grad_sum_tensor = new lite::Tensor(kNumberTypeFloat, shape);
+      grad_sum_tensor->set_tensor_name(in_tensors_.at(grad_idx_)->tensor_name());
+      grad_sum_tensor->set_data(static_cast<void *>(grad_sum_));
+      grad_sum_tensor->set_own_data(false);
+    }
+    return grad_sum_tensor;
+  }
+
   int RestoreDefaultLearningRate() {
     SetLearningRate(default_lr_);
     return RET_OK;
   }

   int SetOptimizerMode(WeightUpdateMode mod) {
-    if (mod == WeightUpdateMode::VIRTUAL_BATCH) {
+    if (mod == WeightUpdateMode::VIRTUAL_BATCH || mod == WeightUpdateMode::ACCUMULATE_GRADS) {
       if (grad_sum_ != nullptr) {
         ms_context_->allocator->Free(grad_sum_);
         grad_sum_ = nullptr;
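Ownership caveat: GetOptimizerParams() and GetGradients() both allocate fresh lite::Tensor wrappers (via CopyTensor and new) that alias kernel memory through set_own_data(false). Callers get live views of the kernel's buffers, but nothing in this hunk frees the wrapper objects themselves, so whoever consumes these vectors is responsible for deleting them.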
@@ -75,7 +135,7 @@ class OptimizerKernel : public InnerKernel {
       }
       valid_grad_sum_ = false;
       std::fill(grad_sum_, grad_sum_ + elem_num, 0);
-      weight_update_mod_ = WeightUpdateMode::VIRTUAL_BATCH;
+      weight_update_mod_ = mod;
     } else {
       if (grad_sum_ != nullptr) {
         OptimizerStep();
@@ -141,6 +201,10 @@ class OptimizerKernel : public InnerKernel {
     }
     return RET_OK;
   }
+  int set_grad_sum_valid() {
+    valid_grad_sum_ = true;
+    return RET_OK;
+  }

  protected:
   float default_lr_ = 0.0f;
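Putting the pieces together, the ACCUMULATE_GRADS life cycle reads: SetOptimizerMode() allocates and zeroes grad_sum_; every training step the kernels route through ExecuteVirtualBatch(), which accumulates into grad_sum_ instead of updating weights; GetGradients() exposes that buffer; and TrainSession::ApplyGradients() (later in this commit) copies aggregated values back in, flags the sum valid via set_grad_sum_valid(), and fires OptimizerStep().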
@@ -634,6 +634,10 @@ void TrainSession::CompileOptimizedKernels() {
   for (auto kernel : this->train_kernels_) {
     if (IsOptimizer(kernel)) {
       std::copy(kernel->in_tensors().begin(), kernel->in_tensors().end(), std::back_inserter(out_tensor));
+      if (cfg_.accumulate_gradients_) {
+        auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
+        optimizer->SetOptimizerMode(kernel::WeightUpdateMode::ACCUMULATE_GRADS);
+      }
     }
   }
@@ -677,21 +681,136 @@ float TrainSession::GetLearningRate() {
   return 0.0;
 }

+std::vector<tensor::MSTensor *> TrainSession::GetOptimizerParams() const {
+  std::vector<tensor::MSTensor *> params;
+  for (auto kernel : this->train_kernels_) {
+    if (IsOptimizer(kernel)) {
+      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
+      auto kernelParams = optimizer->GetOptimizerParams();
+      for (size_t ix = 0; ix < kernelParams.size(); ix++) {
+        auto kernelParam = kernelParams[ix];
+        auto name = kernelParam->tensor_name();
+        bool found = false;
+        for (size_t iy = 0; iy < params.size(); iy++) {
+          if (params[iy]->tensor_name() == name) {
+            found = true;
+            break;
+          }
+        }
+        if (!found) {
+          params.push_back(kernelParam);
+        }
+      }
+    }
+  }
+  return params;
+}
+
+int TrainSession::SetOptimizerParams(const std::vector<tensor::MSTensor *> &params) {
+  for (size_t ix = 0; ix < params.size(); ix++) {
+    auto param = params[ix];
+    if (param == nullptr) {
+      MS_LOG(ERROR) << "Param tensor is null.";
+      return RET_ERROR;
+    }
+    bool found = false;
+    for (auto kernel : this->train_kernels_) {
+      if (IsOptimizer(kernel)) {
+        auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
+        found = optimizer->SetOptimizerParams(param);
+        if (found) break;
+      }
+    }
+    if (!found) {
+      MS_LOG(ERROR) << "Tensor name " << param->tensor_name() << " is not a valid name.";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
+std::vector<tensor::MSTensor *> TrainSession::GetGradients() const {
+  std::vector<tensor::MSTensor *> gradients;
+  for (auto kernel : this->train_kernels_) {
+    if (IsOptimizer(kernel)) {
+      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
+      auto kernelGradient = optimizer->GetGradients();
+      if (kernelGradient != nullptr) {
+        gradients.push_back(kernelGradient);
+      }
+    }
+  }
+  return gradients;
+}
+
+int TrainSession::ApplyGradients(const std::vector<tensor::MSTensor *> &gradients) {
+  auto current_gradients = GetGradients();
+  if (current_gradients.size() != gradients.size()) {
+    MS_LOG(ERROR) << "gradients vector has wrong size " << gradients.size() << " instead of "
+                  << current_gradients.size();
+    return RET_ERROR;
+  }
+  for (size_t ix = 0; ix < gradients.size(); ix++) {
+    auto gradient = gradients[ix];
+    if (gradient == nullptr) {
+      MS_LOG(ERROR) << "gradient tensor is null.";
+      return RET_ERROR;
+    }
+    bool found = false;
+    for (size_t iy = 0; iy < current_gradients.size(); iy++) {
+      auto current_gradient = current_gradients[iy];
+      if (current_gradient->tensor_name() == gradient->tensor_name()) {
+        found = true;
+        if (current_gradient->Size() == gradient->Size()) {
+          std::copy(static_cast<char *>(gradient->data()), static_cast<char *>(gradient->data()) + gradient->Size(),
+                    static_cast<char *>(current_gradient->MutableData()));
+        } else {
+          MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " has wrong size " << gradient->Size()
+                        << " instead of " << current_gradient->Size();
+          return RET_ERROR;
+        }
+        break;
+      }
+    }
+    if (!found) {
+      MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " not found";
+      return RET_ERROR;
+    }
+  }
+  for (auto kernel : this->train_kernels_) {
+    if (IsOptimizer(kernel)) {
+      auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
+      optimizer->set_grad_sum_valid();
+      auto ret = optimizer->OptimizerStep();
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "failed to optimize model weights";
+        return ret;
+      }
+    }
+  }
+  return RET_OK;
+}
+
 int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
-  auto mod = (virtual_batch_multiplier <= 1) ? kernel::OptimizerKernel::WeightUpdateMode::NORMAL
-                                             : kernel::OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH;
+  auto mod =
+    (virtual_batch_multiplier <= 1) ? kernel::WeightUpdateMode::NORMAL : kernel::WeightUpdateMode::VIRTUAL_BATCH;
   virtual_batch_multiplier_ = (virtual_batch_multiplier <= 1) ? 0 : virtual_batch_multiplier;
   virtual_batch_idx_ = 0;

   for (auto kernel : this->train_kernels_) {
     if (IsOptimizer(kernel)) {
       auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
+      if (optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::NORMAL &&
+          optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::VIRTUAL_BATCH) {
+        MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode, conflicts with accumulate-grads mode";
+        return RET_ERROR;
+      }
       auto ret = optimizer->SetOptimizerMode(mod);
       if (ret != RET_OK) {
         MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode";
         return RET_ERROR;
       }
-      if (mod == kernel::OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
+      if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) {
         lr = (lr < 0.0f) ? (optimizer->GetLearningRate() / static_cast<float>(virtual_batch_multiplier_)) : lr;
         ret = optimizer->SetLearningRate(lr);
       } else {
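Because ApplyGradients() matches incoming tensors against the current gradients by tensor_name() and rejects any size mismatch before touching the optimizers, callers may pass gradients in any order, and a single missing or renamed tensor fails the whole call before any weight update happens. One caveat: current_gradients holds wrappers heap-allocated by OptimizerKernel::GetGradients(), and they are not freed on any path through this function.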
@@ -706,7 +825,7 @@ int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr,
     if (IsBN(kernel) && kernel->IsTrainable()) {
       auto batchnorm = static_cast<kernel::BatchnormCPUKernel *>(kernel->kernel());
       auto ret = RET_OK;
-      if (mod == kernel::OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
+      if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) {
         momentum = (momentum < 0.0f) ? (batchnorm->get_momentum() / virtual_batch_multiplier_) : momentum;
         ret = batchnorm->set_momentum(momentum);
       } else {
@@ -62,6 +62,10 @@ class TrainSession : virtual public lite::LiteSession {
   bool IsEval() override { return !train_mode_; }
   int SetLearningRate(float learning_rate) override;
   float GetLearningRate() override;
+  std::vector<tensor::MSTensor *> GetGradients() const override;
+  std::vector<tensor::MSTensor *> GetOptimizerParams() const override;
+  int SetOptimizerParams(const std::vector<tensor::MSTensor *> &params) override;
+  int ApplyGradients(const std::vector<tensor::MSTensor *> &gradients) override;
   int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) override;

   void BindThread(bool if_bind) override { return lite::LiteSession::BindThread(if_bind); }
@@ -117,4 +117,47 @@ TEST_F(TestCxxApiLiteModel, test_metrics_SUCCESS) {
   ASSERT_EQ(metrics.size(), 1);
 }

+TEST_F(TestCxxApiLiteModel, test_getparams_SUCCESS) {
+  Model model;
+  Graph graph;
+  auto context = std::make_shared<Context>();
+  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
+  context->MutableDeviceInfo().push_back(cpu_context);
+  auto train_cfg = std::make_shared<TrainCfg>();
+  train_cfg->accumulate_gradients_ = true;
+
+  ASSERT_TRUE(Serialization::Load("./nets/conv_train_model.ms", ModelType::kMindIR, &graph) == kSuccess);
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+  auto params = model.GetOptimizerParams();
+  ASSERT_EQ(params.size(), 2);
+  float pi = 3.141592647;
+  for (size_t ix = 0; ix < params.size(); ix++) {
+    static_cast<float *>(params[ix].MutableData())[0] = static_cast<float>(ix) + pi;
+  }
+  ASSERT_TRUE(model.SetOptimizerParams(params) == kSuccess);
+  auto params1 = model.GetOptimizerParams();
+  for (size_t ix = 0; ix < params1.size(); ix++) {
+    ASSERT_EQ(static_cast<float *>(params1[ix].MutableData())[0], static_cast<float>(ix) + pi);
+  }
+}
+
+TEST_F(TestCxxApiLiteModel, test_getgrads_SUCCESS) {
+  Model model;
+  Graph graph;
+  auto context = std::make_shared<Context>();
+  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
+  context->MutableDeviceInfo().push_back(cpu_context);
+  auto train_cfg = std::make_shared<TrainCfg>();
+  train_cfg->accumulate_gradients_ = true;
+
+  ASSERT_TRUE(Serialization::Load("./nets/conv_train_model.ms", ModelType::kMindIR, &graph) == kSuccess);
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+  auto gradients = model.GetGradients();
+  ASSERT_EQ(gradients.size(), 2);
+  float pi = 3.141592647;
+  for (size_t ix = 0; ix < gradients.size(); ix++) {
+    static_cast<float *>(gradients[ix].MutableData())[0] = static_cast<float>(ix) + pi;
+  }
+  ASSERT_TRUE(model.ApplyGradients(gradients) == kSuccess);
+}
 }  // namespace mindspore
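Both tests assume ./nets/conv_train_model.ms exposes exactly two optimizer-parameter tensors and two gradient tensors. The exact float round trip in test_getparams_SUCCESS is sound: the literal 3.141592647 is rounded once to a 32-bit float, and that same stored value is written, read back, and compared.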