diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc index 5f25b79c23..5801b241e4 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc @@ -126,6 +126,15 @@ MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr inputs_.push_back(momentum); } +void MomentumOptimInfo::Update(const Values &values, const Lengths &lens) { + size_t lr_offset = 0; + float *lr = values.data() + lr_offset; + auto ret = memcpy_s(inputs_[2]->addr, sizeof(float), lr, sizeof(float)); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } +} + const AddressPtr &MomentumOptimInfo::gradient() { return inputs_[3]; } const AddressPtr &MomentumOptimInfo::indices() { return inputs_[3]; } diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h index dc567e023c..f59d8ad6c1 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h @@ -82,6 +82,7 @@ class MomentumOptimInfo : public DenseOptimInfo { const AddressPtr &gradient, const AddressPtr &momentum); ~MomentumOptimInfo() override = default; + void Update(const Values &values, const Lengths &lens) override; const AddressPtr &gradient(); const AddressPtr &indices(); size_t grad_index() override;