fix updated gamma beta not same with that of torch

This commit is contained in:
zhouyuanshen 2021-03-16 19:21:17 +08:00
parent 6dab729ea1
commit 206762c935
4 changed files with 26 additions and 21 deletions

View File

@ -61,30 +61,30 @@ void CopyMemDevice2Device(const size_t N, const size_t C, float *gamma_addr, flo
}
__global__ void ComputeMeanKernel(const size_t thread_num, const size_t N, const size_t C,
float *save_mean_addr, float *save_var_addr) {
float *dgamma, float *dbeta, const float *ws_dgamma, const float *ws_dbeta) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < thread_num; pos += gridDim.x * blockDim.x) {
size_t cur_addr = pos / C;
size_t cur_local_index = pos % C;
float tmp = 0;
if (cur_addr) {
for (size_t i = 0; i < N; i++) {
tmp += save_var_addr[i * C + cur_local_index];
tmp += ws_dgamma[i * C + cur_local_index];
}
save_var_addr[cur_local_index] = tmp / N;
dgamma[cur_local_index] = tmp;
} else {
for (size_t i = 0; i < N; i++) {
tmp += save_mean_addr[i * C + cur_local_index];
tmp += ws_dbeta[i * C + cur_local_index];
}
save_mean_addr[cur_local_index] = tmp / N;
dbeta[cur_local_index] = tmp;
}
}
return;
}
void ComputeMean(const size_t N, const size_t C,
float *save_mean_addr, float *save_var_addr,
float *dgamma, float *dbeta, const float *ws_dgamma, const float *ws_dbeta,
cudaStream_t cuda_stream) {
size_t thread_num = C * 2;
ComputeMeanKernel<<<GET_BLOCKS(thread_num), GET_THREADS, 0, cuda_stream>>>(
thread_num, N, C, save_mean_addr, save_var_addr);
thread_num, N, C, dgamma, dbeta, ws_dgamma, ws_dbeta);
}

View File

@ -22,6 +22,6 @@ void CopyMemDevice2Device(const size_t N, const size_t C,
float *gamma_addr, float *beta_addr, float *runing_mean_addr, float *runnig_variance_addr,
float *ws_gamma, float *ws_beta, float *ws_mean, float *ws_var,
cudaStream_t cuda_stream);
void ComputeMean(const size_t N, const size_t C, float *save_mean_addr, float *save_var_addr,
cudaStream_t cuda_stream);
void ComputeMean(const size_t N, const size_t C, float *dgamma, float *dbeta, const float *ws_dgamma,
const float *ws_dbeta, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_INSTANCE_NORM_IMPL_H_

View File

@ -76,9 +76,11 @@ class InstanceNormGradGpuKernel : public GpuKernel {
T *dz = nullptr;
float *ws_gamma = GetDeviceAddress<float>(workspace, 0);
float *ws_dgamma = GetDeviceAddress<float>(workspace, 1);
float *ws_dbeta = GetDeviceAddress<float>(workspace, 2);
void *workspace_addr = nullptr;
if (workspace_size_ != 0) {
workspace_addr = GetDeviceAddress<T>(workspace, 1);
workspace_addr = GetDeviceAddress<T>(workspace, 3);
}
size_t N = input_shape_[0];
@ -92,14 +94,14 @@ class InstanceNormGradGpuKernel : public GpuKernel {
const float alpha_param_diff = 1;
const float beta_param_diff = 0;
float *reserve_addr = nullptr;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnBatchNormalizationBackwardEx(
handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff, &beta_param_diff, x_desc_, x,
y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx, scale_bias_diff_desc_, ws_gamma, beta, dgamma, dbeta,
epsilon_, save_mean, save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0),
"Kernel launch failed");
ComputeMean(N, C, dgamma, dbeta, reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnBatchNormalizationBackwardEx(
handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff,
&beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx,
scale_bias_diff_desc_, ws_gamma, beta, ws_dgamma, ws_dbeta, epsilon_, save_mean,
save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0),
"Kernel launch failed");
ComputeMean(N, C, dgamma, dbeta, ws_dgamma, ws_dbeta, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@ -164,10 +166,12 @@ class InstanceNormGradGpuKernel : public GpuKernel {
input_size_list_.push_back(para_size_);
output_size_list_.push_back(x_size_);
output_size_list_.push_back(para_size_);
output_size_list_.push_back(para_size_);
output_size_list_.push_back(x_size_);
output_size_list_.push_back(x_size_);
workspace_size_list_.push_back(para_size_); // ws gamma
workspace_size_list_.push_back(para_size_); // ws dgamma
workspace_size_list_.push_back(para_size_); // ws dbeta
workspace_size_list_.push_back(workspace_size_);
}
void DestroyResource() noexcept override {

View File

@ -31,6 +31,7 @@ from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management
from mindspore.ops import _selected_ops
from mindspore.common import dtype as mstype
from ..cell import Cell
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
@ -999,7 +1000,7 @@ class InstanceNorm2d(Cell):
if not isinstance(val, (Tensor, numbers.Number, str, Initializer)):
raise TypeError(f"[{name}]Supported type for arg {key} is [Tensor, numbers.Number, str, Initializer],"
f"but got {type(val)}")
if isinstance(val, Tensor) and val.dtype is not float:
if isinstance(val, Tensor) and val.dtype != mstype.float32:
raise TypeError(f"[{name}]The type of arg {key} should be float32, but got {val.dtype}")