From 8d9154bfdf8c2b4a23986324239ef2ba15f2e847 Mon Sep 17 00:00:00 2001 From: jiajun169 Date: Wed, 9 Nov 2022 13:08:13 +0800 Subject: [PATCH] Solve the precision problem of fp16 --- .../device/cpu/kernel/prelu_grad_cpu_kernel.cc | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/prelu_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/prelu_grad_cpu_kernel.cc index d22b6c08259..96124bcf888 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/prelu_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/prelu_grad_cpu_kernel.cc @@ -69,6 +69,8 @@ int PReLUGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st return KRET_RESIZE_FAILED; } weight_length_ = weight_shape[0]; + workspace_size_ = weight_length_ * sizeof(float); + workspace_size_list_.push_back(workspace_size_); return KRET_OK; } @@ -90,15 +92,20 @@ bool PReLUGradCpuKernelMod::LaunchKernel(const std::vector &inputs, if (ret != EOK) { MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', output buffer memset failed. Error no: " << ret; } - + auto *dw_array = reinterpret_cast(workspace[0]->addr); + ret = memset_s(dw_array, workspace[0]->size, 0, workspace[0]->size); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', workspace buffer memset failed. Error no: " << ret; + } size_t lens = outputs[0]->size > 0 ? static_cast(outputs[0]->size / sizeof(T)) : 1; - auto task = [this, dy, x, w, dx, dw](size_t start, size_t end) { + auto task = [this, dy, x, w, dx, dw, dw_array](size_t start, size_t end) { for (size_t i = start; i < end; i++) { size_t channel_id = weight_length_ == 1 ? 0 : (i / per_channel_length_) % weight_length_; T threshold = static_cast(0); dx[i] = x[i] <= threshold ? w[channel_id] * dy[i] : dy[i]; if (x[i] < threshold) { - dw[channel_id] = x[i] * dy[i] + dw[channel_id]; + dw_array[channel_id] += static_cast(x[i] * dy[i]); + dw[channel_id] = static_cast(dw_array[channel_id]); } } };