From cd749b90f13e2506a397d27dbd444065d2a2b0a5 Mon Sep 17 00:00:00 2001 From: huangbo77 Date: Wed, 14 Jul 2021 16:26:19 +0800 Subject: [PATCH] fixes SGD --- .../ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.cc index 8e1b4d66f45..40814707d1d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.cc @@ -47,7 +47,7 @@ void SGDCPUKernel::CheckParam(const std::vector &inputs, const st } template -bool SGDCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, +bool SGDCPUKernel::Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) { CheckParam(inputs, outputs); @@ -57,7 +57,8 @@ bool SGDCPUKernel::Launch(const std::vector &inputs, const std::v auto accum = reinterpret_cast(inputs[3]->addr); auto momentum = reinterpret_cast(inputs[4]->addr); auto stat = reinterpret_cast(inputs[5]->addr); - size_t elem_num = inputs[0]->size / sizeof(float); + auto output_param = reinterpret_cast(outputs[0]->addr); + size_t elem_num = inputs[0]->size / sizeof(T); auto task = [&](size_t start, size_t end) { for (size_t i = start; i < end; i++) { @@ -79,6 +80,7 @@ bool SGDCPUKernel::Launch(const std::vector &inputs, const std::v } } param[i] -= lr[0] * grad_new; + output_param[i] = param[i]; } }; CPUKernelUtils::ParallelFor(task, elem_num);