Solve the precision problem of fp16

This commit is contained in:
jiajun169 2022-11-09 13:08:13 +08:00
parent 3da5945080
commit 8d9154bfdf
1 changed files with 10 additions and 3 deletions

View File

@ -69,6 +69,8 @@ int PReLUGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
return KRET_RESIZE_FAILED;
}
weight_length_ = weight_shape[0];
workspace_size_ = weight_length_ * sizeof(float);
workspace_size_list_.push_back(workspace_size_);
return KRET_OK;
}
@ -90,15 +92,20 @@ bool PReLUGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
if (ret != EOK) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', output buffer memset failed. Error no: " << ret;
}
auto *dw_array = reinterpret_cast<float *>(workspace[0]->addr);
ret = memset_s(dw_array, workspace[0]->size, 0, workspace[0]->size);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', workspace buffer memset failed. Error no: " << ret;
}
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
auto task = [this, dy, x, w, dx, dw](size_t start, size_t end) {
auto task = [this, dy, x, w, dx, dw, dw_array](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
size_t channel_id = weight_length_ == 1 ? 0 : (i / per_channel_length_) % weight_length_;
T threshold = static_cast<T>(0);
dx[i] = x[i] <= threshold ? w[channel_id] * dy[i] : dy[i];
if (x[i] < threshold) {
dw[channel_id] = x[i] * dy[i] + dw[channel_id];
dw_array[channel_id] += static_cast<float>(x[i] * dy[i]);
dw[channel_id] = static_cast<T>(dw_array[channel_id]);
}
}
};