Added Federated Learning API

This commit is contained in:
Emir Haleva 2021-09-01 18:24:32 +03:00
parent 5a851daf2f
commit a14eac9fb2
18 changed files with 438 additions and 15 deletions

View File

@ -45,6 +45,7 @@ class TrainCfg {
OptimizationLevel optimization_level_ = kO0;
std::string loss_name_; /**< Set part of the name that identify a loss kernel */
MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */
bool accumulate_gradients_ = false;
};
} // namespace mindspore

View File

@ -92,6 +92,28 @@ class MS_API Model {
/// \return The input tensor with the given name, if the name is not found, an invalid tensor is returned.
inline MSTensor GetInputByTensorName(const std::string &tensor_name);
/// \brief Obtains all gradient tensors of the model.
///
/// \return The vector that includes all gradient tensors.
std::vector<MSTensor> GetGradients() const;
/// \brief update gradient tensors of the model.
///
/// \param[in] inputs A vector new gradients.
/// \return Status of operation
Status ApplyGradients(const std::vector<MSTensor> &gradients);
/// \brief Obtains optimizer params tensors of the model.
///
/// \return The vector that includes all params tensors.
std::vector<MSTensor> GetOptimizerParams() const;
/// \brief update the optimizer parameters
///
/// \param[in] inputs A vector new optimizer params.
/// \return Status of operation
Status SetOptimizerParams(const std::vector<MSTensor> &params);
Status InitMetrics(std::vector<Metrics *> metrics);
std::vector<Metrics *> GetMetrics();

View File

@ -202,6 +202,34 @@ class MS_API LiteSession {
/// \param[in] features new featuremap
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features) { return mindspore::lite::RET_ERROR; }
/// \brief Get model gradient
///
/// \return a vector of gradient tensors (MindSpore Lite MSTensor).
virtual std::vector<tensor::MSTensor *> GetGradients() const {
std::vector<tensor::MSTensor *> gradients;
return gradients;
}
/// \brief update model gradient
///
/// \param[in] new gradients
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int ApplyGradients(const std::vector<tensor::MSTensor *> &gradients) { return mindspore::lite::RET_ERROR; }
/// \brief Get model optimizer params
///
/// \return a vector of optimizer parameters (MindSpore Lite MSTensor).
virtual std::vector<tensor::MSTensor *> GetOptimizerParams() const {
std::vector<tensor::MSTensor *> params;
return params;
}
/// \brief set model optimizer params
///
/// \param[in] new optimizer params
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int SetOptimizerParams(const std::vector<tensor::MSTensor *> &params) { return mindspore::lite::RET_ERROR; }
};
} // namespace session
} // namespace mindspore

View File

@ -55,14 +55,17 @@ class TrainCfg {
TrainCfg(const TrainCfg &rhs) {
this->loss_name_ = rhs.loss_name_;
this->mix_precision_cfg_ = rhs.mix_precision_cfg_;
this->accumulate_gradients_ = rhs.accumulate_gradients_;
}
TrainCfg &operator=(const TrainCfg &rhs) {
this->loss_name_ = rhs.loss_name_;
this->mix_precision_cfg_ = rhs.mix_precision_cfg_;
this->accumulate_gradients_ = rhs.accumulate_gradients_;
return *this;
}
std::string loss_name_; /**< Set part of the name that identify a loss kernel */
MixPrecisionCfg mix_precision_cfg_; /**< Mix precision configuration */
bool accumulate_gradients_ = false; /**< If true gardents are accmulated and can be read by GetGradients */
};
} // namespace lite

View File

@ -202,6 +202,41 @@ Status Model::SetTrainMode(bool train) {
bool Model::GetTrainMode() const { return ((impl_ != nullptr) && (impl_->session_) && (impl_->session_->IsTrain())); }
std::vector<MSTensor> Model::GetGradients() const {
std::vector<MSTensor> empty;
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return empty;
}
return impl_->GetGradients();
}
Status Model::ApplyGradients(const std::vector<MSTensor> &gradients) {
if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
MS_LOG(ERROR) << "Model is null.";
return kLiteUninitializedObj;
}
return impl_->ApplyGradients(gradients);
}
std::vector<MSTensor> Model::GetOptimizerParams() const {
std::vector<MSTensor> empty;
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return empty;
}
auto res = impl_->GetOptimizerParams();
return res;
}
Status Model::SetOptimizerParams(const std::vector<MSTensor> &params) {
if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
MS_LOG(ERROR) << "Model is null.";
return kLiteUninitializedObj;
}
return impl_->SetOptimizerParams(params);
}
Status Model::InitMetrics(std::vector<Metrics *> metrics) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";

View File

@ -349,6 +349,82 @@ std::vector<MSTensor> ModelImpl::GetOutputs() {
return res;
}
std::vector<MSTensor> ModelImpl::GetGradients() const {
std::vector<MSTensor> empty;
if (session_ == nullptr) {
MS_LOG(ERROR) << "Session is null.";
return empty;
}
auto params = session_->GetGradients();
if (params.empty()) {
MS_LOG(ERROR) << "No optimizer parameters avelibale.";
return empty;
}
std::vector<MSTensor> res = LiteTensorsToMSTensors(params);
return res;
}
Status ModelImpl::ApplyGradients(const std::vector<MSTensor> &gradients) {
if (session_ == nullptr) {
MS_LOG(ERROR) << "Session is null.";
return kLiteNullptr;
}
if (gradients.empty()) {
MS_LOG(ERROR) << "gradients is null.";
return kLiteInputParamInvalid;
}
std::vector<tensor::MSTensor *> inner_gradients;
inner_gradients.resize(gradients.size());
for (size_t i = 0; i < gradients.size(); i++) {
auto gradient = gradients[i];
if (gradient.impl_ == nullptr || gradient.impl_->lite_tensor() == nullptr) {
MS_LOG(ERROR) << "gradient tensor " << gradient.Name() << " is null.";
return kLiteInputTensorError;
}
inner_gradients[i] = gradient.impl_->lite_tensor();
}
auto ret = session_->ApplyGradients(inner_gradients);
return static_cast<StatusCode>(ret);
}
std::vector<MSTensor> ModelImpl::GetOptimizerParams() const {
std::vector<MSTensor> empty;
if (session_ == nullptr) {
MS_LOG(ERROR) << "Session is null.";
return empty;
}
auto params = session_->GetOptimizerParams();
if (params.empty()) {
MS_LOG(ERROR) << "No optimizer parameters avelibale.";
return empty;
}
std::vector<MSTensor> res = LiteTensorsToMSTensors(params);
return res;
}
Status ModelImpl::SetOptimizerParams(const std::vector<MSTensor> &params) {
if (session_ == nullptr) {
MS_LOG(ERROR) << "Session is null.";
return kLiteNullptr;
}
if (params.empty()) {
MS_LOG(ERROR) << "params is null.";
return kLiteInputParamInvalid;
}
std::vector<tensor::MSTensor *> inner_params;
inner_params.resize(params.size());
for (size_t i = 0; i < params.size(); i++) {
auto param = params[i];
if (param.impl_ == nullptr || param.impl_->lite_tensor() == nullptr) {
MS_LOG(ERROR) << "Param tensor " << param.Name() << " is null.";
return kLiteInputTensorError;
}
inner_params[i] = param.impl_->lite_tensor();
}
auto ret = session_->SetOptimizerParams(inner_params);
return static_cast<StatusCode>(ret);
}
MSTensor ModelImpl::GetInputByTensorName(const std::string &name) {
if (session_ == nullptr) {
MS_LOG(ERROR) << "Session is null.";

View File

@ -72,6 +72,10 @@ class ModelImpl {
Status LoadConfig(const std::string &config_path);
std::vector<MSTensor> GetInputs();
std::vector<MSTensor> GetOutputs();
std::vector<MSTensor> GetGradients() const;
Status ApplyGradients(const std::vector<MSTensor> &gradients);
std::vector<MSTensor> GetOptimizerParams() const;
Status SetOptimizerParams(const std::vector<MSTensor> &params);
MSTensor GetInputByTensorName(const std::string &name);
std::vector<std::string> GetOutputTensorNames();
MSTensor GetOutputByTensorName(const std::string &name);

View File

@ -35,7 +35,7 @@ Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cf
l_train_cfg->mix_precision_cfg_.loss_scale_ = a_train_cfg->mix_precision_cfg_.loss_scale_;
l_train_cfg->mix_precision_cfg_.keep_batchnorm_fp32_ = (a_train_cfg->optimization_level_ != kO3);
l_train_cfg->mix_precision_cfg_.num_of_not_nan_iter_th_ = a_train_cfg->mix_precision_cfg_.num_of_not_nan_iter_th_;
l_train_cfg->accumulate_gradients_ = a_train_cfg->accumulate_gradients_;
return kSuccess;
}
} // namespace mindspore

View File

@ -1,4 +1,3 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
@ -17,6 +16,7 @@
#include "src/runtime/kernel/arm/fp32_grad/adam.h"
#include <cmath>
#include <string>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
@ -93,7 +93,9 @@ int AdamRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
auto adam_kernel = reinterpret_cast<AdamCPUKernel *>(cdata);
CHECK_NULL_RETURN(adam_kernel);
auto error_code = RET_OK;
if (adam_kernel->get_optimizer_mode() == OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
if (adam_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) {
error_code = adam_kernel->ExecuteVirtualBatch(task_id);
} else if (adam_kernel->get_optimizer_mode() == WeightUpdateMode::ACCUMULATE_GRADS) {
error_code = adam_kernel->ExecuteVirtualBatch(task_id);
} else {
error_code = adam_kernel->Execute(task_id);
@ -125,6 +127,11 @@ int AdamCPUKernel::Init() {
return RET_OK;
}
std::vector<int> AdamCPUKernel::GetOptimizerParamsIdxs() const {
std::vector<int> indices = {6, 7, 3, 4, 8};
return indices;
}
int AdamCPUKernel::OptimizerStep() {
CHECK_LESS_RETURN(in_tensors_.size(), 9);
auto weight = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());

View File

@ -40,6 +40,7 @@ class AdamCPUKernel : public OptimizerKernel {
int Run() override;
int Execute(int task_id);
int OptimizerStep() override;
std::vector<int> GetOptimizerParamsIdxs() const override;
private:
int thread_count_;

View File

@ -1,4 +1,3 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
@ -16,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32_grad/apply_momentum.h"
#include <string>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
@ -74,7 +74,9 @@ int ApplyMomentumRun(void *cdata, int task_id, float lhs_scale, float rhs_scale)
CHECK_NULL_RETURN(cdata);
auto applyMomentum_kernel = reinterpret_cast<ApplyMomentumCPUKernel *>(cdata);
auto error_code = RET_OK;
if (applyMomentum_kernel->get_optimizer_mode() == OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
if (applyMomentum_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) {
error_code = applyMomentum_kernel->ExecuteVirtualBatch(task_id);
} else if (applyMomentum_kernel->get_optimizer_mode() == WeightUpdateMode::ACCUMULATE_GRADS) {
error_code = applyMomentum_kernel->ExecuteVirtualBatch(task_id);
} else {
error_code = applyMomentum_kernel->Execute(task_id);
@ -111,6 +113,11 @@ int ApplyMomentumCPUKernel::Init() {
return RET_OK;
}
std::vector<int> ApplyMomentumCPUKernel::GetOptimizerParamsIdxs() const {
std::vector<int> indices = {4};
return indices;
}
int ApplyMomentumCPUKernel::OptimizerStep() {
auto weight = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
CHECK_NULL_RETURN(weight);

View File

@ -42,6 +42,7 @@ class ApplyMomentumCPUKernel : public OptimizerKernel {
int Execute(int task_id);
int Run() override;
int OptimizerStep() override;
std::vector<int> GetOptimizerParamsIdxs() const override;
private:
int thread_count_;

View File

@ -1,4 +1,3 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
@ -16,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32_grad/sgd.h"
#include <string>
#include <algorithm>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@ -122,7 +122,9 @@ int SgdRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
auto sgd_kernel = reinterpret_cast<SgdCPUKernel *>(cdata);
CHECK_NULL_RETURN(sgd_kernel);
auto error_code = RET_OK;
if (sgd_kernel->get_optimizer_mode() == OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
if (sgd_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) {
error_code = sgd_kernel->ExecuteVirtualBatch(task_id);
} else if (sgd_kernel->get_optimizer_mode() == WeightUpdateMode::ACCUMULATE_GRADS) {
error_code = sgd_kernel->ExecuteVirtualBatch(task_id);
} else {
error_code = sgd_kernel->Execute(task_id);
@ -138,7 +140,7 @@ int SgdRunInit(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
auto sgd_kernel = reinterpret_cast<SgdCPUKernel *>(cdata);
CHECK_NULL_RETURN(sgd_kernel);
auto error_code = RET_OK;
if (sgd_kernel->get_optimizer_mode() == OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
if (sgd_kernel->get_optimizer_mode() == WeightUpdateMode::VIRTUAL_BATCH) {
error_code = sgd_kernel->ExecuteVirtualBatch(task_id);
} else {
error_code = sgd_kernel->ExecuteInit(task_id);
@ -192,6 +194,11 @@ int SgdCPUKernel::Init() {
return RET_OK;
}
std::vector<int> SgdCPUKernel::GetOptimizerParamsIdxs() const {
std::vector<int> indices = {4};
return indices;
}
int SgdCPUKernel::OptimizerStep() {
auto weight = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());

View File

@ -41,6 +41,7 @@ class SgdCPUKernel : public OptimizerKernel {
int ExecuteInit(int task_id);
int Execute(int task_id);
int OptimizerStep() override;
std::vector<int> GetOptimizerParamsIdxs() const override;
private:
int thread_count_;

View File

@ -18,7 +18,11 @@
#include <vector>
#include <cmath>
#include <cfloat>
#include <algorithm>
#include <string>
#include <iostream>
#include "src/lite_kernel.h"
#include "include/ms_tensor.h"
#include "include/errorcode.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
@ -31,6 +35,7 @@ static __attribute__((always_inline)) inline bool MS_ISNAN(float var) {
namespace mindspore::kernel {
enum class WeightUpdateMode { NORMAL, VIRTUAL_BATCH, ACCUMULATE_GRADS };
class OptimizerKernel : public InnerKernel {
public:
OptimizerKernel() = default;
@ -39,7 +44,6 @@ class OptimizerKernel : public InnerKernel {
: InnerKernel(parameter, inputs, outputs, ctx), lr_idx_(lr_idx), grad_idx_(grad_idx) {}
~OptimizerKernel() = default;
enum class WeightUpdateMode { NORMAL, VIRTUAL_BATCH };
WeightUpdateMode get_optimizer_mode() { return weight_update_mod_; }
int Init() override {
@ -55,13 +59,69 @@ class OptimizerKernel : public InnerKernel {
float GetLearningRate() { return lr_; }
virtual std::vector<int> GetOptimizerParamsIdxs() const {
std::vector<int> indices;
return indices;
}
std::vector<tensor::MSTensor *> GetOptimizerParams() const {
std::vector<tensor::MSTensor *> params;
auto indices = GetOptimizerParamsIdxs();
indices.push_back(lr_idx_);
for (size_t ix = 0; ix < indices.size(); ix++) {
auto param = lite::Tensor::CopyTensor(*in_tensors_.at(indices[ix]));
param->set_tensor_name(in_tensors_.at(indices[ix])->tensor_name());
param->set_data(static_cast<void *>(in_tensors_.at(indices[ix])->data()));
param->set_own_data(false);
if (param->data() == nullptr) {
MS_LOG(ERROR) << "Tensor: " << param->tensor_name() << "has no data";
return params;
}
params.push_back(param);
}
return params;
}
bool SetOptimizerParams(tensor::MSTensor *param) {
if (param == nullptr) {
return false;
}
bool found = false;
auto indices = GetOptimizerParamsIdxs();
indices.push_back(lr_idx_);
for (size_t ix = 0; ix < indices.size(); ix++) {
if (param->tensor_name() == in_tensors_.at(indices[ix])->tensor_name()) {
auto value = static_cast<float *>(param->MutableData())[0];
static_cast<float *>(in_tensors_.at(indices[ix])->MutableData())[0] = value;
if (lr_idx_ == indices[ix]) {
lr_ = value;
}
found = true;
break;
}
}
return found;
}
lite::Tensor *GetGradients() {
lite::Tensor *grad_sum_tensor = nullptr;
if (grad_sum_ != nullptr) {
auto shape = in_tensors_.at(grad_idx_)->shape();
grad_sum_tensor = new lite::Tensor(kNumberTypeFloat, shape);
grad_sum_tensor->set_tensor_name(in_tensors_.at(grad_idx_)->tensor_name());
grad_sum_tensor->set_data(static_cast<void *>(grad_sum_));
grad_sum_tensor->set_own_data(false);
}
return grad_sum_tensor;
}
int RestoreDefaultLearningRate() {
SetLearningRate(default_lr_);
return RET_OK;
}
int SetOptimizerMode(WeightUpdateMode mod) {
if (mod == WeightUpdateMode::VIRTUAL_BATCH) {
if (mod == WeightUpdateMode::VIRTUAL_BATCH || mod == WeightUpdateMode::ACCUMULATE_GRADS) {
if (grad_sum_ != nullptr) {
ms_context_->allocator->Free(grad_sum_);
grad_sum_ = nullptr;
@ -75,7 +135,7 @@ class OptimizerKernel : public InnerKernel {
}
valid_grad_sum_ = false;
std::fill(grad_sum_, grad_sum_ + elem_num, 0);
weight_update_mod_ = WeightUpdateMode::VIRTUAL_BATCH;
weight_update_mod_ = mod;
} else {
if (grad_sum_ != nullptr) {
OptimizerStep();
@ -141,6 +201,10 @@ class OptimizerKernel : public InnerKernel {
}
return RET_OK;
}
int set_grad_sum_valid() {
valid_grad_sum_ = true;
return RET_OK;
}
protected:
float default_lr_ = 0.0f;

View File

@ -634,6 +634,10 @@ void TrainSession::CompileOptimizedKernels() {
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
std::copy(kernel->in_tensors().begin(), kernel->in_tensors().end(), std::back_inserter(out_tensor));
if (cfg_.accumulate_gradients_) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
optimizer->SetOptimizerMode(kernel::WeightUpdateMode::ACCUMULATE_GRADS);
}
}
}
@ -677,21 +681,136 @@ float TrainSession::GetLearningRate() {
return 0.0;
}
std::vector<tensor::MSTensor *> TrainSession::GetOptimizerParams() const {
std::vector<tensor::MSTensor *> params;
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
auto kernelParams = optimizer->GetOptimizerParams();
for (size_t ix = 0; ix < kernelParams.size(); ix++) {
auto kernelParam = kernelParams[ix];
auto name = kernelParam->tensor_name();
bool found = false;
for (size_t iy = 0; iy < params.size(); iy++) {
if (params[iy]->tensor_name() == name) {
found = true;
break;
}
}
if (!found) {
params.push_back(kernelParam);
}
}
}
}
return params;
}
int TrainSession::SetOptimizerParams(const std::vector<tensor::MSTensor *> &params) {
for (size_t ix = 0; ix < params.size(); ix++) {
auto param = params[ix];
if (param == nullptr) {
MS_LOG(ERROR) << "Param tensor " << param->tensor_name() << " is null.";
return RET_ERROR;
}
bool found = false;
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
found = optimizer->SetOptimizerParams(param);
if (found) break;
}
}
if (!found) {
MS_LOG(ERROR) << "Tensor name " << param->tensor_name() << " is not a valid name.";
return RET_ERROR;
}
}
return RET_OK;
}
std::vector<tensor::MSTensor *> TrainSession::GetGradients() const {
std::vector<tensor::MSTensor *> params;
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
auto kernelGradint = optimizer->GetGradients();
if (kernelGradint != nullptr) {
params.push_back(kernelGradint);
}
}
}
return params;
}
int TrainSession::ApplyGradients(const std::vector<tensor::MSTensor *> &gradients) {
auto current_gradients = GetGradients();
if (current_gradients.size() != gradients.size()) {
MS_LOG(ERROR) << "gradients vector has wrong size " << gradients.size() << " instead of "
<< current_gradients.size();
return RET_ERROR;
}
for (size_t ix = 0; ix < gradients.size(); ix++) {
auto gradient = gradients[ix];
if (gradient == nullptr) {
MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " is null.";
return RET_ERROR;
}
bool found = false;
for (size_t iy = 0; iy < current_gradients.size(); iy++) {
auto current_gradient = current_gradients[iy];
if (current_gradient->tensor_name() == gradient->tensor_name()) {
found = true;
if (current_gradient->Size() == gradient->Size()) {
std::copy(static_cast<char *>(gradient->data()), static_cast<char *>(gradient->data()) + gradient->Size(),
static_cast<char *>(current_gradient->MutableData()));
} else {
MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " has wrong size " << gradient->Size()
<< " instead of " << current_gradient->Size();
return RET_ERROR;
}
break;
}
}
if (!found) {
MS_LOG(ERROR) << "gradient tensor " << gradient->tensor_name() << " not found";
return RET_ERROR;
}
}
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
optimizer->set_grad_sum_valid();
auto ret = optimizer->OptimizerStep();
if (ret != RET_OK) {
MS_LOG(ERROR) << "failed to optimize model weights";
return ret;
}
}
}
return RET_OK;
}
int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
auto mod = (virtual_batch_multiplier <= 1) ? kernel::OptimizerKernel::WeightUpdateMode::NORMAL
: kernel::OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH;
auto mod =
(virtual_batch_multiplier <= 1) ? kernel::WeightUpdateMode::NORMAL : kernel::WeightUpdateMode::VIRTUAL_BATCH;
virtual_batch_multiplier_ = (virtual_batch_multiplier <= 1) ? 0 : virtual_batch_multiplier;
virtual_batch_idx_ = 0;
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = static_cast<kernel::OptimizerKernel *>(kernel->kernel());
if (optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::NORMAL &&
optimizer->get_optimizer_mode() != kernel::WeightUpdateMode::VIRTUAL_BATCH) {
MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode, conflict with accumulate grads";
return RET_ERROR;
}
auto ret = optimizer->SetOptimizerMode(mod);
if (ret != RET_OK) {
MS_LOG(ERROR) << kernel->name() << " failed to set optimizer mode";
return RET_ERROR;
}
if (mod == kernel::OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) {
lr = (lr < 0.0f) ? (optimizer->GetLearningRate() / static_cast<float>(virtual_batch_multiplier_)) : lr;
ret = optimizer->SetLearningRate(lr);
} else {
@ -706,7 +825,7 @@ int TrainSession::AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr,
if (IsBN(kernel) && kernel->IsTrainable()) {
auto batchnorm = static_cast<kernel::BatchnormCPUKernel *>(kernel->kernel());
auto ret = RET_OK;
if (mod == kernel::OptimizerKernel::WeightUpdateMode::VIRTUAL_BATCH) {
if (mod == kernel::WeightUpdateMode::VIRTUAL_BATCH) {
momentum = (momentum < 0.0f) ? (batchnorm->get_momentum() / virtual_batch_multiplier_) : momentum;
ret = batchnorm->set_momentum(momentum);
} else {

View File

@ -62,6 +62,10 @@ class TrainSession : virtual public lite::LiteSession {
bool IsEval() override { return !train_mode_; }
int SetLearningRate(float learning_rate) override;
float GetLearningRate() override;
std::vector<tensor::MSTensor *> GetGradients() const override;
std::vector<tensor::MSTensor *> GetOptimizerParams() const override;
int SetOptimizerParams(const std::vector<tensor::MSTensor *> &params) override;
int ApplyGradients(const std::vector<tensor::MSTensor *> &gradients) override;
int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) override;
void BindThread(bool if_bind) override { return lite::LiteSession::BindThread(if_bind); }

View File

@ -117,4 +117,47 @@ TEST_F(TestCxxApiLiteModel, test_metrics_SUCCESS) {
ASSERT_EQ(metrics.size(), 1);
}
TEST_F(TestCxxApiLiteModel, test_getparams_SUCCESS) {
Model model;
Graph graph;
auto context = std::make_shared<Context>();
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
context->MutableDeviceInfo().push_back(cpu_context);
auto train_cfg = std::make_shared<TrainCfg>();
train_cfg->accumulate_gradients_ = true;
ASSERT_TRUE(Serialization::Load("./nets/conv_train_model.ms", ModelType::kMindIR, &graph) == kSuccess);
ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
auto params = model.GetOptimizerParams();
ASSERT_EQ(params.size(), 2);
float pi = 3.141592647;
for (size_t ix = 0; ix < params.size(); ix++) {
static_cast<float *>(params[ix].MutableData())[0] = static_cast<float>(ix) + pi;
}
ASSERT_TRUE(model.SetOptimizerParams(params) == kSuccess);
auto params1 = model.GetOptimizerParams();
for (size_t ix = 0; ix < params1.size(); ix++) {
ASSERT_EQ(static_cast<float *>(params1[ix].MutableData())[0], static_cast<float>(ix) + pi);
}
}
TEST_F(TestCxxApiLiteModel, test_getgrads_SUCCESS) {
Model model;
Graph graph;
auto context = std::make_shared<Context>();
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
context->MutableDeviceInfo().push_back(cpu_context);
auto train_cfg = std::make_shared<TrainCfg>();
train_cfg->accumulate_gradients_ = true;
ASSERT_TRUE(Serialization::Load("./nets/conv_train_model.ms", ModelType::kMindIR, &graph) == kSuccess);
ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
auto graients = model.GetGradients();
ASSERT_EQ(graients.size(), 2);
float pi = 3.141592647;
for (size_t ix = 0; ix < graients.size(); ix++) {
static_cast<float *>(graients[ix].MutableData())[0] = static_cast<float>(ix) + pi;
}
ASSERT_TRUE(model.ApplyGradients(graients) == kSuccess);
}
} // namespace mindspore