forked from mindspore-Ecosystem/mindspore
!12746 Fix ApplyAdagradCpuKernel output
From: @yang_chun Reviewed-by: @wuxuejian,@c_34 Signed-off-by: @wuxuejian
This commit is contained in:
commit
99c17c43d8
|
@ -38,9 +38,9 @@ bool ApplyAdagradCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const
|
|||
CheckParam(inputs, outputs);
|
||||
|
||||
if (dtype_ == kNumberTypeFloat16) {
|
||||
LaunchKernel<float16>(inputs);
|
||||
LaunchKernel<float16>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32) {
|
||||
LaunchKernel<float>(inputs);
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -67,7 +67,8 @@ void ApplyAdagradCPUKernel::CheckParam(const std::vector<AddressPtr> &inputs, co
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ApplyAdagradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs) {
|
||||
void ApplyAdagradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto var = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto accum = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto lr = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
|
@ -96,6 +97,17 @@ void ApplyAdagradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs)
|
|||
for (auto &it : threads) {
|
||||
it.join();
|
||||
}
|
||||
|
||||
// Copy result to output tensor
|
||||
auto output_var = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
auto output_accum = reinterpret_cast<T *>(outputs[1]->addr);
|
||||
if (memcpy_s(output_var, outputs[0]->size, var, inputs[0]->size) != EOK) {
|
||||
MS_LOG(EXCEPTION) << "Launch kernel error: memcpy failed.";
|
||||
}
|
||||
|
||||
if (memcpy_s(output_accum, outputs[1]->size, accum, inputs[1]->size) != EOK) {
|
||||
MS_LOG(EXCEPTION) << "Launch kernel error: memcpy failed.";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -36,7 +36,7 @@ class ApplyAdagradCPUKernel : public CPUKernel {
|
|||
private:
|
||||
static void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs);
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
template <typename T>
|
||||
void LaunchApplyAdagrad(T const var, T const accum, const T lr, const T gradient, size_t start, size_t end);
|
||||
bool update_slots_{true};
|
||||
|
|
|
@ -36,8 +36,7 @@ class Net(nn.Cell):
|
|||
self.accum = Parameter(Tensor(accum_np), name="accum")
|
||||
|
||||
def construct(self, lr, grad):
|
||||
self.apply_adagrad(self.var, self.accum, lr, grad)
|
||||
return self.var, self.accum
|
||||
return self.apply_adagrad(self.var, self.accum, lr, grad)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
Loading…
Reference in New Issue