forked from OSSInnovation/mindspore
!4151 Fix ps training precision error
Merge pull request !4151 from ZPaC/master-fix-ps-training-precision-error
This commit is contained in:
commit
9efbfb8af1
|
@ -126,6 +126,15 @@ MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr
|
|||
inputs_.push_back(momentum);
|
||||
}
|
||||
|
||||
void MomentumOptimInfo::Update(const Values &values, const Lengths &lens) {
|
||||
size_t lr_offset = 0;
|
||||
float *lr = values.data() + lr_offset;
|
||||
auto ret = memcpy_s(inputs_[2]->addr, sizeof(float), lr, sizeof(float));
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
}
|
||||
|
||||
const AddressPtr &MomentumOptimInfo::gradient() { return inputs_[3]; }
|
||||
|
||||
const AddressPtr &MomentumOptimInfo::indices() { return inputs_[3]; }
|
||||
|
|
|
@ -82,6 +82,7 @@ class MomentumOptimInfo : public DenseOptimInfo {
|
|||
const AddressPtr &gradient, const AddressPtr &momentum);
|
||||
~MomentumOptimInfo() override = default;
|
||||
|
||||
void Update(const Values &values, const Lengths &lens) override;
|
||||
const AddressPtr &gradient();
|
||||
const AddressPtr &indices();
|
||||
size_t grad_index() override;
|
||||
|
|
Loading…
Reference in New Issue