Solve the precision problem of fp16
This commit is contained in:
parent
3da5945080
commit
8d9154bfdf
|
@ -69,6 +69,8 @@ int PReLUGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
|
|||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
weight_length_ = weight_shape[0];
|
||||
workspace_size_ = weight_length_ * sizeof(float);
|
||||
workspace_size_list_.push_back(workspace_size_);
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
|
@ -90,15 +92,20 @@ bool PReLUGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', output buffer memset failed. Error no: " << ret;
|
||||
}
|
||||
|
||||
auto *dw_array = reinterpret_cast<float *>(workspace[0]->addr);
|
||||
ret = memset_s(dw_array, workspace[0]->size, 0, workspace[0]->size);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', workspace buffer memset failed. Error no: " << ret;
|
||||
}
|
||||
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
|
||||
auto task = [this, dy, x, w, dx, dw](size_t start, size_t end) {
|
||||
auto task = [this, dy, x, w, dx, dw, dw_array](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
size_t channel_id = weight_length_ == 1 ? 0 : (i / per_channel_length_) % weight_length_;
|
||||
T threshold = static_cast<T>(0);
|
||||
dx[i] = x[i] <= threshold ? w[channel_id] * dy[i] : dy[i];
|
||||
if (x[i] < threshold) {
|
||||
dw[channel_id] = x[i] * dy[i] + dw[channel_id];
|
||||
dw_array[channel_id] += static_cast<float>(x[i] * dy[i]);
|
||||
dw[channel_id] = static_cast<T>(dw_array[channel_id]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue