forked from mindspore-Ecosystem/mindspore

fix updated gamma beta not same with that of torch

commit 206762c935, parent 6dab729ea1
@@ -61,30 +61,30 @@ void CopyMemDevice2Device(const size_t N, const size_t C, float *gamma_addr, flo
 }
 
 __global__ void ComputeMeanKernel(const size_t thread_num, const size_t N, const size_t C,
-                                  float *save_mean_addr, float *save_var_addr) {
+                                  float *dgamma, float *dbeta, const float *ws_dgamma, const float *ws_dbeta) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < thread_num; pos += gridDim.x * blockDim.x) {
     size_t cur_addr = pos / C;
     size_t cur_local_index = pos % C;
     float tmp = 0;
     if (cur_addr) {
       for (size_t i = 0; i < N; i++) {
-        tmp += save_var_addr[i * C + cur_local_index];
+        tmp += ws_dgamma[i * C + cur_local_index];
       }
-      save_var_addr[cur_local_index] = tmp / N;
+      dgamma[cur_local_index] = tmp;
     } else {
       for (size_t i = 0; i < N; i++) {
-        tmp += save_mean_addr[i * C + cur_local_index];
+        tmp += ws_dbeta[i * C + cur_local_index];
       }
-      save_mean_addr[cur_local_index] = tmp / N;
+      dbeta[cur_local_index] = tmp;
     }
   }
   return;
 }
 
 void ComputeMean(const size_t N, const size_t C,
-                 float *save_mean_addr, float *save_var_addr,
+                 float *dgamma, float *dbeta, const float *ws_dgamma, const float *ws_dbeta,
                  cudaStream_t cuda_stream) {
   size_t thread_num = C * 2;
   ComputeMeanKernel<<<GET_BLOCKS(thread_num), GET_THREADS, 0, cuda_stream>>>(
-    thread_num, N, C, save_mean_addr, save_var_addr);
+    thread_num, N, C, dgamma, dbeta, ws_dgamma, ws_dbeta);
 }
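Note the semantics change inside the loop: the old code averaged over the batch (tmp / N), while the new code sums the per-instance gradients and writes them into the real dgamma/dbeta outputs instead of clobbering save_mean/save_var. A minimal NumPy sketch of the reduction this kernel now performs (array shapes are an assumption for illustration, not part of the diff):

    import numpy as np

    def compute_mean_kernel(N, C, ws_dgamma, ws_dbeta):
        # ws_dgamma / ws_dbeta hold one C-sized gradient per instance, laid out
        # as N*C, matching the i * C + cur_local_index indexing above.
        dgamma = ws_dgamma.reshape(N, C).sum(axis=0)  # was effectively .mean(axis=0) before the fix
        dbeta = ws_dbeta.reshape(N, C).sum(axis=0)
        return dgamma, dbeta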
@@ -22,6 +22,6 @@ void CopyMemDevice2Device(const size_t N, const size_t C,
                           float *gamma_addr, float *beta_addr, float *runing_mean_addr, float *runnig_variance_addr,
                           float *ws_gamma, float *ws_beta, float *ws_mean, float *ws_var,
                           cudaStream_t cuda_stream);
-void ComputeMean(const size_t N, const size_t C, float *save_mean_addr, float *save_var_addr,
-                 cudaStream_t cuda_stream);
+void ComputeMean(const size_t N, const size_t C, float *dgamma, float *dbeta, const float *ws_dgamma,
+                 const float *ws_dbeta, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_INSTANCE_NORM_IMPL_H_
@@ -76,9 +76,11 @@ class InstanceNormGradGpuKernel : public GpuKernel {
     T *dz = nullptr;
 
     float *ws_gamma = GetDeviceAddress<float>(workspace, 0);
+    float *ws_dgamma = GetDeviceAddress<float>(workspace, 1);
+    float *ws_dbeta = GetDeviceAddress<float>(workspace, 2);
     void *workspace_addr = nullptr;
     if (workspace_size_ != 0) {
-      workspace_addr = GetDeviceAddress<T>(workspace, 1);
+      workspace_addr = GetDeviceAddress<T>(workspace, 3);
     }
 
     size_t N = input_shape_[0];
@@ -92,14 +94,14 @@ class InstanceNormGradGpuKernel : public GpuKernel {
     const float alpha_param_diff = 1;
     const float beta_param_diff = 0;
     float *reserve_addr = nullptr;
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudnnBatchNormalizationBackwardEx(
-        handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff, &beta_param_diff, x_desc_, x,
-        y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx, scale_bias_diff_desc_, ws_gamma, beta, dgamma, dbeta,
-        epsilon_, save_mean, save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0),
-      "Kernel launch failed");
-    ComputeMean(N, C, dgamma, dbeta, reinterpret_cast<cudaStream_t>(stream_ptr));
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                cudnnBatchNormalizationBackwardEx(
+                                  handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff,
+                                  &beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx,
+                                  scale_bias_diff_desc_, ws_gamma, beta, ws_dgamma, ws_dbeta, epsilon_, save_mean,
+                                  save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0),
+                                "Kernel launch failed");
+    ComputeMean(N, C, dgamma, dbeta, ws_dgamma, ws_dbeta, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
 
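The call now routes cudnn's per-instance dgamma/dbeta into the new workspaces and only afterwards reduces them into the user-visible outputs. As for the behaviour the commit title refers to: torch accumulates the affine gradients of InstanceNorm2d by summing the per-instance contributions over the batch, not averaging them, which a short PyTorch check confirms (shapes and tolerances here are arbitrary):

    import torch

    N, C, H, W = 4, 3, 5, 5
    m = torch.nn.InstanceNorm2d(C, affine=True)
    x = torch.randn(N, C, H, W)
    g = torch.randn(N, C, H, W)          # arbitrary upstream gradient
    (m(x) * g).sum().backward()

    # Recompute the per-instance dgamma/dbeta by hand: one (N, C) entry per instance.
    x_hat = (x - x.mean((2, 3), keepdim=True)) / torch.sqrt(
        x.var((2, 3), unbiased=False, keepdim=True) + m.eps)
    per_instance_dgamma = (g * x_hat).sum((2, 3))
    per_instance_dbeta = g.sum((2, 3))

    print(torch.allclose(m.weight.grad, per_instance_dgamma.sum(0), atol=1e-5))   # True: sum over N
    print(torch.allclose(m.weight.grad, per_instance_dgamma.mean(0), atol=1e-5))  # False: the old tmp / N
    print(torch.allclose(m.bias.grad, per_instance_dbeta.sum(0), atol=1e-5))      # True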
@@ -164,10 +166,12 @@ class InstanceNormGradGpuKernel : public GpuKernel {
     input_size_list_.push_back(para_size_);
 
     output_size_list_.push_back(x_size_);
-    output_size_list_.push_back(para_size_);
-    output_size_list_.push_back(para_size_);
+    output_size_list_.push_back(x_size_);
+    output_size_list_.push_back(x_size_);
 
     workspace_size_list_.push_back(para_size_);  // ws gamma
+    workspace_size_list_.push_back(para_size_);  // ws dgamma
+    workspace_size_list_.push_back(para_size_);  // ws dbeta
     workspace_size_list_.push_back(workspace_size_);
   }
   void DestroyResource() noexcept override {
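The three new para_size_ entries back the GetDeviceAddress<float>(workspace, 0..2) calls in Launch, with the cudnn scratch buffer pushed last at index 3. A sketch of the layout, assuming para_size_ covers one float per (instance, channel) pair (that definition is not spelled out in this diff):

    SIZEOF_FLOAT = 4

    def workspace_layout(N, C, cudnn_workspace_size):
        para_size = N * C * SIZEOF_FLOAT  # assumed definition of para_size_
        return [
            para_size,              # index 0: ws_gamma, gamma replicated across N instances
            para_size,              # index 1: ws_dgamma, per-instance dgamma written by cudnn
            para_size,              # index 2: ws_dbeta, per-instance dbeta written by cudnn
            cudnn_workspace_size,   # index 3: scratch for cudnnBatchNormalizationBackwardEx
        ]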
@@ -31,6 +31,7 @@ from mindspore._extends import cell_attr_register
 from mindspore.communication.management import get_group_size, get_rank
 from mindspore.communication import management
 from mindspore.ops import _selected_ops
+from mindspore.common import dtype as mstype
 from ..cell import Cell
 
 __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
@@ -999,7 +1000,7 @@ class InstanceNorm2d(Cell):
             if not isinstance(val, (Tensor, numbers.Number, str, Initializer)):
                 raise TypeError(f"[{name}]Supported type for arg {key} is [Tensor, numbers.Number, str, Initializer],"
                                 f"but got {type(val)}")
-            if isinstance(val, Tensor) and val.dtype is not float:
+            if isinstance(val, Tensor) and val.dtype != mstype.float32:
                 raise TypeError(f"[{name}]The type of arg {key} should be float32, but got {val.dtype}")