diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_grad_impl.cu new file mode 100644 index 00000000000..d51f11357ba --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_grad_impl.cu @@ -0,0 +1,199 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" +#include "runtime/device/gpu/cuda_common.h" +#include "backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_grad_impl.cuh" + +const int kWarpSize = 32; +const int kNumWarps = 32; + +__inline__ __device__ float HalfFloatInputConvert(const half val) { return __half2float(val); } +__inline__ __device__ float HalfFloatInputConvert(const float val) { return val; } +__inline__ __device__ void HalfFloatOutputAssign(const float val, float *arr, int idx) { arr[idx] = val; } +__inline__ __device__ void HalfFloatOutputAssign(const float val, half *arr, int idx) { arr[idx] = __float2half(val); } + +template +__global__ void SyncBatchNormGradPre(size_t N, size_t C, size_t H, size_t W, const T *x_input, const T *dy, + G *saved_mean, G *saved_invstd, float *dy_sum_local, float *dot_p_local) { + // block level memory + __shared__ float shared_dy[kNumWarps]; + __shared__ float shared_dot_p[kNumWarps]; + int warpId = threadIdx.x / kWarpSize; // threads are arranged in warps of 32 executed together + int laneId = threadIdx.x % kWarpSize; + + int plane = blockIdx.x; // this thread will only function on a single plane + int plane_size = N * H * W; + float mean = static_cast(saved_mean[plane]); + + if (threadIdx.x < kNumWarps) { + shared_dy[threadIdx.x] = static_cast(0); + shared_dot_p[threadIdx.x] = static_cast(0); + } + + __syncthreads(); // ensure all 0 init complete across all values + + float dy_sum = static_cast(0); + float dot_p = static_cast(0); + + // individual thread level reduction + for (int x = threadIdx.x; x < plane_size; x += blockDim.x) { + int index = (x / (H * W) * C * H * W) + (plane * H * W) + (x % (H * W)); + float input_value = HalfFloatInputConvert(x_input[index]); + float dy_value = HalfFloatInputConvert(dy[index]); + dy_sum += dy_value; + dot_p += (input_value - mean) * dy_value; + } + __syncthreads(); + // warp reduce all values in every value to a single value + for (int offset = kWarpSize / 2; offset > 0; offset /= 2) { + float other_dy_sum = __shfl_down_sync(0xffffffff, dy_sum, offset); + float other_dot_p = __shfl_down_sync(0xffffffff, dot_p, offset); + dy_sum += other_dy_sum; + dot_p += other_dot_p; + } + __syncwarp(); + if (laneId == 0) { + shared_dy[warpId] = dy_sum; + shared_dot_p[warpId] = dot_p; + // one value per warp now + } + __syncthreads(); + if (warpId == 0) { + dy_sum = shared_dy[laneId]; + dot_p = shared_dot_p[laneId]; + __syncwarp(); + for (int offset = kWarpSize / 2; offset > 0; offset /= 2) { + float other_dy = 
__shfl_down_sync(0xffffffff, dy_sum, offset); + float other_dot_p = __shfl_down_sync(0xffffffff, dot_p, offset); + dy_sum += other_dy; + dot_p += other_dot_p; + } + __syncwarp(); + } + if (threadIdx.x == 0) { + dy_sum_local[plane] = dy_sum; + dot_p_local[plane] = dot_p; + } + return; +} + +template +__global__ void SyncBatchNormGradPost(size_t N, size_t C, size_t H, size_t W, const T *x_input, const T *dy, T *dx, + G *saved_mean, G *saved_invstd, float *dy_sum_red, float *dot_p_red, S *scale, + S *dscale, S *dbias, float epsilon) { + int size = N * C * H * W; + int plane_size = N * H * W; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + int block_num = (pos / W) / H; // which of N * C blocks + int plane = block_num % C; + float mean = HalfFloatInputConvert(saved_mean[plane]); + float invstd = HalfFloatInputConvert(saved_invstd[plane]); + float scale_value = HalfFloatInputConvert(scale[plane]); + float div_factor = HalfFloatInputConvert(1) / plane_size; + float dy_sum_plane = dy_sum_red[plane]; + float dot_p_plane = dot_p_red[plane]; + float grad_mean = dy_sum_plane * div_factor; + float proj_scale = dot_p_plane * div_factor * invstd * invstd; + float grad_scale = invstd * scale_value; + float inp = HalfFloatInputConvert(x_input[pos]); + float proj = (inp - mean) * proj_scale; + HalfFloatOutputAssign((HalfFloatInputConvert(dy[pos]) - proj - grad_mean) * grad_scale, dx, pos); + } +} + +template +__global__ void SyncBatchNormGradPostScaleBias(size_t C, G *saved_invstd, float *dy_sum_red, float *dot_p_red, + S *dscale, S *dbias) { + for (size_t plane = blockIdx.x * blockDim.x + threadIdx.x; plane < C; plane += blockDim.x * gridDim.x) { + float invstd = HalfFloatInputConvert(saved_invstd[plane]); + float dy_sum_plane = dy_sum_red[plane]; + float dot_p_plane = dot_p_red[plane]; + dscale[plane] = static_cast(dot_p_plane * invstd); + dbias[plane] = static_cast(dy_sum_plane); + } +} + +template +void CalSyncBatchNormGradPre(size_t N, size_t C, size_t H, size_t W, const T *x_input, const T *dy, G *saved_mean, + G *saved_invstd, float *dy_sum_local, float *dot_p_local, cudaStream_t cuda_stream) { + SyncBatchNormGradPre<<>>(N, C, H, W, x_input, dy, saved_mean, saved_invstd, + dy_sum_local, dot_p_local); + return; +} +template +void CalSyncBatchNormGradPost(size_t N, size_t C, size_t H, size_t W, const T *x_input, const T *dy, T *dx, + G *saved_mean, G *saved_invstd, float *dy_sum_red, float *dot_p_red, S *scale, S *dscale, + S *dbias, float epsilon, cudaStream_t cuda_stream) { + SyncBatchNormGradPost<<>>(N, C, H, W, x_input, dy, dx, saved_mean, saved_invstd, + dy_sum_red, dot_p_red, scale, dscale, dbias, epsilon); + SyncBatchNormGradPostScaleBias<<(GET_THREADS)), 0, cuda_stream>>>( + C, saved_invstd, dy_sum_red, dot_p_red, dscale, dbias); +} +// PRE FUNCTION +template void CalSyncBatchNormGradPre(size_t N, size_t C, size_t H, size_t W, const float *x_input, + const float *dy, float *saved_mean, float *saved_invstd, + float *dy_sum_local, float *dot_p_local, cudaStream_t cuda_stream); +template void CalSyncBatchNormGradPre(size_t N, size_t C, size_t H, size_t W, const float *x_input, + const float *dy, half *saved_mean, half *saved_invstd, + float *dy_sum_local, float *dot_p_local, cudaStream_t cuda_stream); +template void CalSyncBatchNormGradPre(size_t N, size_t C, size_t H, size_t W, const half *x_input, + const half *dy, float *saved_mean, float *saved_invstd, + float *dy_sum_local, float *dot_p_local, cudaStream_t cuda_stream); +template void 
CalSyncBatchNormGradPre(size_t N, size_t C, size_t H, size_t W, const half *x_input, + const half *dy, half *saved_mean, half *saved_invstd, + float *dy_sum_local, float *dot_p_local, cudaStream_t cuda_stream); +// POST FUNCTION +template void CalSyncBatchNormGradPost(size_t N, size_t C, size_t H, size_t W, + const float *x_input, const float *dy, float *dx, + float *saved_mean, float *saved_invstd, float *dy_sum_red, + float *dot_p_red, float *scale, float *dscale, float *dbias, + float epsilon, cudaStream_t cuda_stream); +template void CalSyncBatchNormGradPost(size_t N, size_t C, size_t H, size_t W, const half *x_input, + const half *dy, half *dx, float *saved_mean, + float *saved_invstd, float *dy_sum_red, float *dot_p_red, + float *scale, float *dscale, float *dbias, float epsilon, + cudaStream_t cuda_stream); +template void CalSyncBatchNormGradPost(size_t N, size_t C, size_t H, size_t W, const float *x_input, + const float *dy, float *dx, float *saved_mean, + float *saved_invstd, float *dy_sum_red, float *dot_p_red, + half *scale, half *dscale, half *dbias, float epsilon, + cudaStream_t cuda_stream); +template void CalSyncBatchNormGradPost(size_t N, size_t C, size_t H, size_t W, const half *x_input, + const half *dy, half *dx, float *saved_mean, + float *saved_invstd, float *dy_sum_red, float *dot_p_red, + half *scale, half *dscale, half *dbias, float epsilon, + cudaStream_t cuda_stream); +template void CalSyncBatchNormGradPost(size_t N, size_t C, size_t H, size_t W, const float *x_input, + const float *dy, float *dx, half *saved_mean, + half *saved_invstd, float *dy_sum_red, float *dot_p_red, + float *scale, float *dscale, float *dbias, float epsilon, + cudaStream_t cuda_stream); +template void CalSyncBatchNormGradPost(size_t N, size_t C, size_t H, size_t W, const half *x_input, + const half *dy, half *dx, half *saved_mean, + half *saved_invstd, float *dy_sum_red, float *dot_p_red, + float *scale, float *dscale, float *dbias, float epsilon, + cudaStream_t cuda_stream); +template void CalSyncBatchNormGradPost(size_t N, size_t C, size_t H, size_t W, const float *x_input, + const float *dy, float *dx, half *saved_mean, + half *saved_invstd, float *dy_sum_red, float *dot_p_red, + half *scale, half *dscale, half *dbias, float epsilon, + cudaStream_t cuda_stream); +template void CalSyncBatchNormGradPost(size_t N, size_t C, size_t H, size_t W, const half *x_input, + const half *dy, half *dx, half *saved_mean, half *saved_invstd, + float *dy_sum_red, float *dot_p_red, half *scale, half *dscale, + half *dbias, float epsilon, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_grad_impl.cuh new file mode 100644 index 00000000000..78eaa2c1b5f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_grad_impl.cuh @@ -0,0 +1,27 @@ +// /** +// * Copyright 2021 Huawei Technologies Co., Ltd +// * +// * Licensed under the Apache License, Version 2.0 (the "License"); +// * you may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SYNC_BATCH_NORM_GRAD_IMPL_CUH +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SYNC_BATCH_NORM_GRAD_IMPL_CUH +#include "runtime/device/gpu/cuda_common.h" +template +void CalSyncBatchNormGradPre(size_t N, size_t C, size_t H, size_t W, const T *x_input, const T *dy, G *saved_mean, + G *invstd_saved, float *dy_sum_local, float *dot_p_local, cudaStream_t cuda_stream); +template +void CalSyncBatchNormGradPost(size_t N, size_t C, size_t H, size_t W, const T *x_input, const T *dy, T *dx, + G *saved_mean, G *invstd_saved, float *dy_sum_red, float *dot_p_red, S *scale, S *dscale, + S *dbias, float epsilon, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SYNC_BATCH_NORM_GRAD_IMPL_CUH diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_impl.cu new file mode 100644 index 00000000000..cd15c37105b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_impl.cu @@ -0,0 +1,248 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" +#include "runtime/device/gpu/cuda_common.h" +#include "backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_impl.cuh" + +const int kWarpSize = 32; +const int kNumWarps = 32; + +__inline__ __device__ float HalfFloatInputConvert(const half val) { return __half2float(val); } +__inline__ __device__ float HalfFloatInputConvert(const float val) { return val; } +__inline__ __device__ void HalfFloatOutputAssign(const float val, float *arr, int idx) { arr[idx] = val; } +__inline__ __device__ void HalfFloatOutputAssign(const float val, half *arr, int idx) { arr[idx] = __float2half(val); } + +template +__global__ void SyncBatchNormPre(size_t N, size_t C, size_t H, size_t W, const T *input, int *output_n, + float *output_mean, float *output_invstd, float epsilon) { + // block level memory + __shared__ float shared_mean[kNumWarps]; + __shared__ float shared_var[kNumWarps]; + __shared__ int shared_n[kNumWarps]; + + int warpId = threadIdx.x / kWarpSize; // threads execute in warps of 32 + int laneId = threadIdx.x % kWarpSize; + int plane = blockIdx.x; + int plane_size = N * H * W; + if (threadIdx.x < kNumWarps) { + shared_mean[threadIdx.x] = static_cast(0); + shared_var[threadIdx.x] = static_cast(0); + } + // ensure all 0 init complete across all values + __syncthreads(); + + // agg values + float avg = 0; + float var_n = 0; + int n = 0; + + // individual thread level reduction + for (int x = threadIdx.x; x < plane_size; x += blockDim.x) { + int index = (x / (H * W) * C * H * W) + (plane * H * W) + (x % (H * W)); + float input_val = HalfFloatInputConvert(input[index]); + float d1 = input_val - avg; + n++; + avg = avg + (d1 / n); + var_n = var_n + (d1 * (input_val - avg)); + } + __syncthreads(); + + // Reduce every warp to a single value + for (int offset = kWarpSize / 2; offset > 0; offset /= 2) { + float other_avg = __shfl_down_sync(0xffffffff, avg, offset); + float other_n = __shfl_down_sync(0xffffffff, n, offset); + float div_factor = 1.0 / fmaxf(1.0, n + other_n); + float other_var_n = __shfl_down_sync(0xffffffff, var_n, offset); + var_n += other_var_n + (avg - other_avg) * (avg - other_avg) * n * other_n * div_factor; + avg = (n * avg + other_n * other_avg) * div_factor; + n += other_n; + } + __syncwarp(); + if (laneId == 0) { + // lane 0 for every warp moves value + shared_n[warpId] = n; + shared_mean[warpId] = avg; + shared_var[warpId] = var_n; + // now one value per warp + } + // second reduction to reduce all warps into a single value + __syncthreads(); + if (warpId == 0) { + n = shared_n[laneId]; + avg = shared_mean[laneId]; + var_n = shared_var[laneId]; + __syncwarp(); + for (int offset = kWarpSize / 2; offset > 0; offset /= 2) { + int other_n = __shfl_down_sync(0xffffffff, n, offset); + float other_avg = __shfl_down_sync(0xffffffff, avg, offset); + float div_factor = 1.0 / fmaxf(1.0, n + other_n); + float other_var_n = __shfl_down_sync(0xffffffff, var_n, offset); + var_n += other_var_n + (avg - other_avg) * (avg - other_avg) * n * other_n * div_factor; + avg = (n * avg + other_n * other_avg) * div_factor; + n += other_n; + } + __syncwarp(); + } + if (threadIdx.x == 0) { + output_n[plane] = n; + output_mean[plane] = avg; + output_invstd[plane] = static_cast(1) / sqrt((var_n / plane_size) + epsilon); + } + return; +} + +template +__global__ void SyncBatchNormGather(size_t N, size_t C, size_t H, size_t W, int *counts_global, float *means_global, + float *invstds_global, int *counts_local, float *means_local, float 
*invstds_local, + T *running_mean_output, T *running_var_output, G *running_mean_input, + G *running_var_input, float epsilon, float momentum, size_t group_rank, + size_t group_size) { + int feature_size = C; + int world_size = group_size; + for (size_t C_ix = blockIdx.x * blockDim.x + threadIdx.x; C_ix < C; C_ix += blockDim.x * gridDim.x) { + float avg = 0; + float var_n = 0; + float n = 0; + for (int N_ix = 0; N_ix < world_size; N_ix++) { + int count = counts_global[N_ix * feature_size + C_ix]; + float mean_ = means_global[N_ix * feature_size + C_ix]; + float std = static_cast(1) / invstds_global[N_ix * feature_size + C_ix]; + float var_n_ = (std * std - epsilon) * count; + float div_factor = 1.0 / fmaxf(1.0, n + count); + var_n += var_n_ + (avg - mean_) * (avg - mean_) * n * count * div_factor; + avg = n * div_factor * avg + count * div_factor * mean_; + n += count; + } + means_local[C_ix] = avg; + invstds_local[C_ix] = static_cast(1) / sqrt((var_n / n) + epsilon); + HalfFloatOutputAssign(((1 - momentum) * HalfFloatInputConvert(running_mean_input[C_ix]) + momentum * avg), + running_mean_output, C_ix); + float unbiasedVar = 0.0; + if (n != 0) { // not strictly required since pipeline does not allow empty inputs + unbiasedVar = var_n / n; + } + HalfFloatOutputAssign(((1 - momentum) * HalfFloatInputConvert(running_var_input[C_ix]) + momentum * unbiasedVar), + running_var_output, C_ix); + } + return; +} + +template +__global__ void SyncBatchNormPost(size_t N, size_t C, size_t H, size_t W, const T *input, T *output, float *means_local, + float *invstds_local, S *scale, S *bias, float epsilon) { + int size = N * C * H * W; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + int block_num = (pos / W) / H; // which of N * C blocks + int plane = block_num % C; + float scale_plane = HalfFloatInputConvert(scale[plane]); + float bias_plane = HalfFloatInputConvert(bias[plane]); + float mean_plane = means_local[plane]; + float invstd_plane = invstds_local[plane]; + float input_val = HalfFloatInputConvert(input[pos]); + HalfFloatOutputAssign(scale_plane * (input_val - mean_plane) * invstd_plane + bias_plane, output, pos); + } + return; +} + +template +__global__ void SyncBatchNormPostBiasScale(size_t C, S *scale, S *bias, S *output_scale, S *output_bias) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < C; pos += blockDim.x * gridDim.x) { + output_bias[pos] = bias[pos]; + output_scale[pos] = scale[pos]; + } + return; +} + +template +void CalSyncBatchNormPre(size_t N, size_t C, size_t H, size_t W, const T *input, int *output_n, float *output_mean, + float *output_var, float epsilon, cudaStream_t cuda_stream) { + SyncBatchNormPre<<>>(N, C, H, W, input, output_n, output_mean, output_var, epsilon); + return; +} + +template +void CalSyncBatchNormGather(size_t N, size_t C, size_t H, size_t W, int *counts_global, float *means_global, + float *invstds_global, int *counts_local, float *means_local, float *invstds_local, + T *running_mean_output, T *running_var_output, G *running_mean_input, G *running_var_input, + float epsilon, float momentum, size_t group_rank, size_t group_size, + cudaStream_t cuda_stream) { + SyncBatchNormGather<<>>( + N, C, H, W, counts_global, means_global, invstds_global, counts_local, means_local, invstds_local, + running_mean_output, running_var_output, running_mean_input, running_var_input, epsilon, momentum, group_rank, + group_size); + return; +} + +template +void CalSyncBatchNormPost(size_t N, size_t C, size_t 
H, size_t W, const T *input, T *output, float *means_local, + float *invstds_local, S *scale, S *bias, S *output_scale, S *output_bias, float epsilon, + cudaStream_t cuda_stream) { + SyncBatchNormPost<<>>(N, C, H, W, input, output, means_local, + invstds_local, scale, bias, epsilon); + SyncBatchNormPostBiasScale<<<1, std::min(C, static_cast(GET_THREADS)), 0, cuda_stream>>>( + C, scale, bias, output_scale, output_bias); + return; +} + +template void CalSyncBatchNormPre(size_t N, size_t C, size_t H, size_t W, const float *input, int *output_n, + float *output_mean, float *output_var, float epsilon, + cudaStream_t cuda_stream); +template void CalSyncBatchNormPre(size_t N, size_t C, size_t H, size_t W, const half *input, int *output_n, + float *output_mean, float *output_var, float epsilon, cudaStream_t cuda_stream); + +template void CalSyncBatchNormGather(size_t N_, size_t C_, size_t H_, size_t W_, int *counts_global, + float *means_global, float *invstds_global, int *counts_local, + float *means_local, float *invstds_local, float *running_mean_output, + float *running_var_output, float *running_mean_input, + float *running_var_input, float epsilon, float momentum, + size_t group_rank, size_t group_size, cudaStream_t cuda_stream); +template void CalSyncBatchNormGather(size_t N_, size_t C_, size_t H_, size_t W_, int *counts_global, + float *means_global, float *invstds_global, int *counts_local, + float *means_local, float *invstds_local, float *running_mean_output, + float *running_var_output, half *running_mean_input, + half *running_var_input, float epsilon, float momentum, + size_t group_rank, size_t group_size, cudaStream_t cuda_stream); +template void CalSyncBatchNormGather(size_t N_, size_t C_, size_t H_, size_t W_, int *counts_global, + float *means_global, float *invstds_global, int *counts_local, + float *means_local, float *invstds_local, half *running_mean_output, + half *running_var_output, float *running_mean_input, + float *running_var_input, float epsilon, float momentum, + size_t group_rank, size_t group_size, cudaStream_t cuda_stream); +template void CalSyncBatchNormGather(size_t N_, size_t C_, size_t H_, size_t W_, int *counts_global, + float *means_global, float *invstds_global, int *counts_local, + float *means_local, float *invstds_local, half *running_mean_output, + half *running_var_output, half *running_mean_input, + half *running_var_input, float epsilon, float momentum, + size_t group_rank, size_t group_size, cudaStream_t cuda_stream); + +template void CalSyncBatchNormPost(size_t N, size_t C, size_t H, size_t W, const float *input, + float *output, float *means_local, float *invstds_local, float *scale, + float *bias, float *output_scale, float *output_bias, float epsilon, + cudaStream_t cuda_stream); +template void CalSyncBatchNormPost(size_t N, size_t C, size_t H, size_t W, const float *input, + float *output, float *means_local, float *invstds_local, half *scale, + half *bias, half *output_scale, half *output_bias, float epsilon, + cudaStream_t cuda_stream); +template void CalSyncBatchNormPost(size_t N, size_t C, size_t H, size_t W, const half *input, half *output, + float *means_local, float *invstds_local, float *scale, float *bias, + float *output_scale, float *output_bias, float epsilon, + cudaStream_t cuda_stream); +template void CalSyncBatchNormPost(size_t N, size_t C, size_t H, size_t W, const half *input, half *output, + float *means_local, float *invstds_local, half *scale, half *bias, + half *output_scale, half *output_bias, float epsilon, + cudaStream_t 
cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_impl.cuh new file mode 100644 index 00000000000..708eb8f2f9a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_impl.cuh @@ -0,0 +1,33 @@ +// /** +// * Copyright 2021 Huawei Technologies Co., Ltd +// * +// * Licensed under the Apache License, Version 2.0 (the "License"); +// * you may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SYNC_BATCH_NORM_IMPL_CUH +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SYNC_BATCH_NORM_IMPL_CUH +#include "runtime/device/gpu/cuda_common.h" +template +void CalSyncBatchNormPre(size_t N, size_t C, size_t H, size_t W, const T *input, int *output_n, float *means_local, + float *invstds_local, float epsilon, cudaStream_t cuda_stream); +template +void CalSyncBatchNormGather(size_t N, size_t C, size_t H, size_t W, int *counts_global, float *means_global, + float *invstds_global, int *counts_local, float *means_local, float *invstds_local, + T *running_mean_output, T *running_var_output, G *running_mean_input, G *running_var_input, + float epsilon, float momentum, size_t group_rank, size_t group_size, + cudaStream_t cuda_stream); +template +void CalSyncBatchNormPost(size_t N, size_t C, size_t H, size_t W, const T *input, T *output, float *means_local, + float *invstds_local, S *scale, S *bias, S *output_scale, S *output_bias, float epsilon, + cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SYNC_BATCH_NORM_IMPL_CUH diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_gpu_kernel.cc new file mode 100644 index 00000000000..7dce5bb4510 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_gpu_kernel.cc @@ -0,0 +1,126 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/gpu/nccl/sync_batch_norm_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SyncBatchNormGpuKernel, float, float, float) +MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + SyncBatchNormGpuKernel, half, float, float) +MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SyncBatchNormGpuKernel, float, half, float) +MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + SyncBatchNormGpuKernel, half, half, float) +MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SyncBatchNormGpuKernel, float, float, half) +MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + SyncBatchNormGpuKernel, half, float, half) +MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SyncBatchNormGpuKernel, float, half, half) +MS_REG_GPU_KERNEL_THREE(SyncBatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + 
.AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + SyncBatchNormGpuKernel, half, half, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_gpu_kernel.h new file mode 100644 index 00000000000..bfbd84e3577 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_gpu_kernel.h @@ -0,0 +1,246 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "utils/utils.h" +#include "backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class SyncBatchNormGpuKernel : public NcclGpuKernel { + public: + SyncBatchNormGpuKernel() { ResetResource(); } + ~SyncBatchNormGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *x = GetDeviceAddress(inputs, 0); + S *scale = GetDeviceAddress(inputs, 1); + S *bias = GetDeviceAddress(inputs, 2); + G *running_mean_input = GetDeviceAddress(inputs, 3); + G *running_variance_input = GetDeviceAddress(inputs, 4); + + float *means_local = GetDeviceAddress(workspace, 0); // per device + float *invstds_local = GetDeviceAddress(workspace, 1); + int *counts_local = GetDeviceAddress(workspace, 2); + int *counts_global = GetDeviceAddress(workspace, 3); // gathered values from all devices + float *means_global = GetDeviceAddress(workspace, 4); + float *invstds_global = GetDeviceAddress(workspace, 5); + + T *y = GetDeviceAddress(outputs, 0); + S *output_scale = GetDeviceAddress(outputs, 1); + S *output_bias = GetDeviceAddress(outputs, 2); + T *output_running_mean = GetDeviceAddress(outputs, 3); + T *output_running_variance = GetDeviceAddress(outputs, 4); + + // aggregate means and invstd on each device locally + CalSyncBatchNormPre(N_, C_, H_, W_, x, counts_local, means_local, invstds_local, epsilon_, + reinterpret_cast(stream_ptr)); + // gather values from all 
devices together + LaunchAllGather(means_local, means_global, stream_ptr); + LaunchAllGather(invstds_local, invstds_global, stream_ptr); + LaunchAllGather(counts_local, counts_global, stream_ptr); + // reducing gathered values on each device and deal with running means and variance + CalSyncBatchNormGather(N_, C_, H_, W_, counts_global, means_global, invstds_global, counts_local, means_local, + invstds_local, output_running_mean, output_running_variance, running_mean_input, + running_variance_input, epsilon_, momentum_, group_rank_, group_size_, + reinterpret_cast(stream_ptr)); + CalSyncBatchNormPost(N_, C_, H_, W_, x, y, means_local, invstds_local, scale, bias, output_scale, output_bias, + epsilon_, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + auto root_rank = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kAttrRootRank); + if (root_rank) { + root_ = static_cast(GetValue(root_rank)); + } + nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)); + group_name_ = GetAttr(kernel_node, kAttrGroup); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 5) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but SyncBatchNorm needs 5 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 5) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but SyncBatchNorm needs 5 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (CHECK_NULL_INPUT(input_shape)) { + MS_LOG(WARNING) << "SyncBatchNorm input is null"; + InitSizeLists(); + return true; + } + auto input_shape_dims = input_shape.size(); + if (input_shape_dims != 4 && input_shape_dims != 2) { + MS_LOG(EXCEPTION) << "Tensor shape is " << input_shape.size() + << ", SyncBatchNormGpuKernel input should be 2D or 4D"; + } + input_size_ = 1; + for (auto dim : input_shape) { + input_size_ *= dim; + } + epsilon_ = GetAttr(kernel_node, "epsilon"); + momentum_ = GetAttr(kernel_node, "momentum"); + output_size_ = input_size_; + output_size_ = output_size_ * sizeof(T); + input_size_ = input_size_ * sizeof(T); + param_count_ = input_shape[1]; // C is number of features + param_size_S_ = param_count_ * sizeof(S); // will be second/third template + param_size_G_input_ = param_count_ * sizeof(G); + param_size_G_output_ = param_count_ * sizeof(T); + workspace_size_ = param_count_; // specific size computed in InitSizeLists() + N_ = input_shape[0]; + C_ = input_shape[1]; + if (input_shape_dims == 2) { + // NC -> N,C,1,1 transform input dims + H_ = 1; + W_ = 1; + } else { + H_ = input_shape[2]; + W_ = input_shape[3]; + } + // MULTI DEVICE SPECIFICS + group_name_ = GetAttr(kernel_node, kAttrGroup); + MS_LOG(INFO) << AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_; + auto comm_stream_attr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stream_id"); + if (comm_stream_attr) { + comm_stream_ = reinterpret_cast(GetValue(comm_stream_attr)); + MS_EXCEPTION_IF_NULL(comm_stream_); + } + collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); + MS_EXCEPTION_IF_NULL(collective_handle_); + // Get group size + auto get_group_size_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "GetGroupRanks")); + MS_EXCEPTION_IF_NULL(get_group_size_funcptr); + std::vector group_ranks = (*get_group_size_funcptr)(group_name_); + group_size_ = group_ranks.size(); + // // Get device rank 
ID in group + using GetLocalRankId = device::gpu::GetLocalRankId; + auto get_local_rank_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "local_rank_id")); + MS_EXCEPTION_IF_NULL(get_local_rank_funcptr); + group_rank_ = IntToUint((*get_local_rank_funcptr)()); + InitSizeLists(); + return true; + } + + void ResetResource() noexcept override { + input_size_ = 0; + output_size_ = 0; + workspace_size_ = 0; + momentum_ = 0; + epsilon_ = 10e-5; + param_size_S_ = 0; + param_size_G_input_ = 0; + param_size_G_output_ = 0; + param_count_ = 0; + N_ = 0; + C_ = 0; + H_ = 0; + W_ = 0; + root_ = 0; + collective_handle_ = nullptr; + comm_stream_ = nullptr; + nccl_reduce_type_ = ncclSum; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); // input x + input_size_list_.push_back(param_size_S_); // scale + input_size_list_.push_back(param_size_S_); // bias + input_size_list_.push_back(param_size_G_input_); // running mean + input_size_list_.push_back(param_size_G_input_); // running variance + output_size_list_.push_back(output_size_); // output + output_size_list_.push_back(param_size_S_); // save scale + output_size_list_.push_back(param_size_S_); // reserve space + output_size_list_.push_back(param_size_G_output_); // save mean + output_size_list_.push_back(param_size_G_output_); // save variance + // local mean/variance data - per device + workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // mean_local + workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // invstd_local + workspace_size_list_.push_back(workspace_size_ * sizeof(int)); // count_local + // global mean/variance data - for all devices + workspace_size_list_.push_back(workspace_size_ * sizeof(int) * group_size_); // gathered mean + workspace_size_list_.push_back(workspace_size_ * sizeof(float) * group_size_); // gathered invstd + workspace_size_list_.push_back(workspace_size_ * sizeof(float) * group_size_); // gathered count + } + + private: + // GetTypeID functions return the correct typeID for input template + // Allow for a single templated LaunchAllGather function + mindspore::TypeId GetTypeID(float *input) { return kNumberTypeFloat32; } + mindspore::TypeId GetTypeID(int *input) { return kNumberTypeInt32; } + template + void LaunchAllGather(gather_type *input_addr, gather_type *output_addr, void *stream_ptr) { + cudaStream_t stream = comm_stream_ ? 
comm_stream_ : reinterpret_cast(stream_ptr); + auto all_gather_funcptr = reinterpret_cast(dlsym(const_cast(collective_handle_), "AllGather")); + MS_EXCEPTION_IF_NULL(all_gather_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT( + kernel_node_, + (*all_gather_funcptr)(input_addr, output_addr, C_, nccl_dtype(GetTypeID(input_addr)), stream, group_name_), + "ncclAllGather failed"); + } + + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + float momentum_; + float epsilon_; + size_t param_size_S_; + size_t param_size_G_input_; + size_t param_size_G_output_; + size_t param_count_; + size_t N_; + size_t C_; + size_t H_; + size_t W_; + size_t group_size_; + size_t group_rank_; + ncclRedOp_t nccl_reduce_type_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + // NCCL + string group_name_; + int root_; + const void *collective_handle_; + cudaStream_t comm_stream_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_grad_gpu_kernel.cc new file mode 100644 index 00000000000..ebfbcc7aeed --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_grad_gpu_kernel.cc @@ -0,0 +1,110 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/gpu/nccl/sync_batch_norm_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SyncBatchNormGradGpuKernel, float, float, float) +MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SyncBatchNormGradGpuKernel, half, float, float) +MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + SyncBatchNormGradGpuKernel, float, half, float) +MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + SyncBatchNormGradGpuKernel, half, half, float) +MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SyncBatchNormGradGpuKernel, float, float, half) +MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SyncBatchNormGradGpuKernel, half, float, half) +MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + SyncBatchNormGradGpuKernel, float, half, half) +MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + SyncBatchNormGradGpuKernel, half, half, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_grad_gpu_kernel.h 
b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_grad_gpu_kernel.h new file mode 100644 index 00000000000..d185b208e76 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/sync_batch_norm_grad_gpu_kernel.h @@ -0,0 +1,209 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GRAD_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "utils/utils.h" +#include "backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_grad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class SyncBatchNormGradGpuKernel : public NcclGpuKernel { + public: + SyncBatchNormGradGpuKernel() { ResetResource(); } + ~SyncBatchNormGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *dy = GetDeviceAddress(inputs, 0); + T *x_input = GetDeviceAddress(inputs, 1); + S *scale = GetDeviceAddress(inputs, 2); + G *saved_mean = GetDeviceAddress(inputs, 3); + G *saved_variance = GetDeviceAddress(inputs, 4); + float *dy_sum_local = GetDeviceAddress(workspace, 0); + float *dot_p_local = GetDeviceAddress(workspace, 1); + float *dy_sum_red = GetDeviceAddress(workspace, 2); + float *dot_p_red = GetDeviceAddress(workspace, 3); + T *dx = GetDeviceAddress(outputs, 0); + S *dscale = GetDeviceAddress(outputs, 1); + S *dbias = GetDeviceAddress(outputs, 2); + // aggregate interim values on each device locally + CalSyncBatchNormGradPre(N_, C_, H_, W_, x_input, dy, saved_mean, saved_variance, dy_sum_local, dot_p_local, + reinterpret_cast(stream_ptr)); + // reduce values across devices + LaunchAllReduce(dy_sum_local, dy_sum_red, stream_ptr); + LaunchAllReduce(dot_p_local, dot_p_red, stream_ptr); + // Aggregate and compute output + CalSyncBatchNormGradPost(N_, C_, H_, W_, x_input, dy, dx, saved_mean, saved_variance, dy_sum_red, dot_p_red, scale, + dscale, dbias, epsilon_, reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + auto root_rank = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kAttrRootRank); + if (root_rank) { + root_ = static_cast(GetValue(root_rank)); + } + nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)); + group_name_ = GetAttr(kernel_node, kAttrGroup); + size_t input_num = 
AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 5) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but SyncBatchNormGrad needs 5 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 3) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but SyncBatchNormGrad needs 5 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (CHECK_NULL_INPUT(input_shape)) { + MS_LOG(WARNING) << "SyncBatchNormGrad input is null"; + InitSizeLists(); + return true; + } + auto input_shape_dims = input_shape.size(); + if (input_shape_dims != 4 && input_shape_dims != 2) { + MS_LOG(EXCEPTION) << "Tensor shape is " << input_shape.size() + << ", SyncBatchNormGpuGrad input should be 2D or 4D"; + } + input_size_ = 1; + for (auto dim : input_shape) { + input_size_ *= dim; + } + output_size_ = input_size_; + output_size_ = output_size_ * sizeof(T); + input_size_ = input_size_ * sizeof(T); + param_count_ = input_shape[1]; + param_size_S_ = param_count_ * sizeof(S); + param_size_G_ = param_count_ * sizeof(G); + N_ = input_shape[0]; + C_ = input_shape[1]; + if (input_shape_dims == 2) { // N,C,1,1 transform input + H_ = 1; + W_ = 1; + } else { + H_ = input_shape[2]; + W_ = input_shape[3]; + } + workspace_size_ = C_; + epsilon_ = GetAttr(kernel_node, "epsilon"); + // MULTIDEVICE SPECIFICS + group_name_ = GetAttr(kernel_node, kAttrGroup); + MS_LOG(INFO) << AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_; + auto comm_stream_attr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stream_id"); + if (comm_stream_attr) { + comm_stream_ = reinterpret_cast(GetValue(comm_stream_attr)); + MS_EXCEPTION_IF_NULL(comm_stream_); + } + collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); + MS_EXCEPTION_IF_NULL(collective_handle_); + // Get group size + auto get_group_size_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "GetGroupRanks")); + MS_EXCEPTION_IF_NULL(get_group_size_funcptr); + std::vector group_ranks = (*get_group_size_funcptr)(group_name_); + device_count_ = group_ranks.size(); + InitSizeLists(); + return true; + } + + void ResetResource() noexcept override { + input_size_ = 0; + output_size_ = 0; + workspace_size_ = 0; + epsilon_ = 10e-5; // default + param_size_S_ = 0; + param_size_G_ = 0; + param_count_ = 0; + N_ = 0; + C_ = 0; + H_ = 0; + W_ = 0; + root_ = 0; + collective_handle_ = nullptr; + comm_stream_ = nullptr; + nccl_reduce_type_ = ncclSum; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); // dy + input_size_list_.push_back(input_size_); // x + input_size_list_.push_back(param_size_S_); // scale + input_size_list_.push_back(param_size_G_); // saved_mean + input_size_list_.push_back(param_size_G_); // saved_variance + output_size_list_.push_back(output_size_); // dx + output_size_list_.push_back(param_size_S_); // dscale + output_size_list_.push_back(param_size_S_); // dbias + workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // sum_dy + workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // sum_dy_xmu + workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // reduced sum_dy + workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // reduced sum_dy_xmu + } + + private: + template + void LaunchAllReduce(reduce_type 
*input_addr, reduce_type *output_addr, void *stream_ptr) { + cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast(stream_ptr); + auto all_reduce_funcptr = reinterpret_cast(dlsym(const_cast(collective_handle_), "AllReduce")); + MS_EXCEPTION_IF_NULL(all_reduce_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, + (*all_reduce_funcptr)(input_addr, output_addr, C_, nccl_dtype(kNumberTypeFloat32), + nccl_reduce_type_, stream, group_name_), + "ncclAllReduce - SyncBatchNormGrad - CUDA failed"); + } + + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + float epsilon_; + size_t param_size_S_; + size_t param_size_G_; + size_t param_count_; + size_t N_; + size_t C_; + size_t H_; + size_t W_; + size_t device_count_; + ncclRedOp_t nccl_reduce_type_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + // NCCL + string group_name_; + int root_; + const void *collective_handle_; + cudaStream_t comm_stream_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index e4cc61a5784..c358c9508a1 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -108,9 +108,13 @@ class _BatchNorm(Cell): self.is_global = True self.group_device_num = self.rank_size self.device_list = [i for i in range(0, self.rank_size)] - if SYNC_BN_GROUP_NAME == "": - SYNC_BN_GROUP_NAME = "sync_bn_group0" - management.create_group(SYNC_BN_GROUP_NAME, self.device_list) + if context.get_context("device_target") == "Ascend": + if SYNC_BN_GROUP_NAME == "": + SYNC_BN_GROUP_NAME = "sync_bn_group0" + management.create_group(SYNC_BN_GROUP_NAME, self.device_list) + elif context.get_context("device_target") == "GPU": + if SYNC_BN_GROUP_NAME == "": + SYNC_BN_GROUP_NAME = "nccl_world_group" self.shape = P.Shape() self.reduce_mean = P.ReduceMean(keep_dims=True) diff --git a/tests/st/nccl/test_nccl_sync_batch_norm_op.py b/tests/st/nccl/test_nccl_sync_batch_norm_op.py new file mode 100644 index 00000000000..65d47246e43 --- /dev/null +++ b/tests/st/nccl/test_nccl_sync_batch_norm_op.py @@ -0,0 +1,134 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import numpy as np +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.communication.management import init +from mindspore.ops import composite as C + +# define target and input values here +x_fwd_input = np.array([[ + [[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]], + [[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32) +expect_output_fwd = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294], + [-0.1471, 0.7706, 1.6882, 2.6059], + [0.3118, 1.6882, 2.1471, 2.1471], + [0.7706, 0.3118, 2.6059, -0.1471]], + [[0.9119, 1.8518, 1.3819, -0.0281], + [-0.0281, 0.9119, 1.3819, 1.8518], + [2.7918, 0.4419, -0.4981, 0.9119], + [1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32) +grad_back = np.array([[[[1, 2, 7, 1], [4, 2, 1, 3], [1, 6, 5, 2], [2, 4, 3, 2]], + [[9, 4, 3, 5], [1, 3, 7, 6], [5, 7, 9, 9], [1, 4, 6, 8]]]]).astype(np.float32) +expect_output_back = np.array([[[[-0.69126546, -0.32903028, 1.9651246, -0.88445705], + [0.6369296, -0.37732816, -0.93275493, -0.11168876], + [-0.7878612, 1.3614, 0.8542711, -0.52222186], + [-0.37732816, 0.5886317, -0.11168876, -0.28073236]], + [[1.6447213, -0.38968924, -1.0174079, -0.55067265], + [-2.4305856, -1.1751484, 0.86250514, 0.5502673], + [0.39576983, 0.5470243, 1.1715001, 1.6447213], + [-1.7996241, -0.7051701, 0.7080077, 0.5437813]]]]).astype(np.float32) + +class Net(nn.Cell): + def __init__(self, c): + super(Net, self).__init__() + self.num_features = c + self.eps = 1e-5 + self.momentum = 1 + self.mode = True + self.affine = True + self.sync_bn_op = nn.SyncBatchNorm(num_features=self.num_features, + eps=self.eps, + momentum=self.momentum, + affine=self.affine, + gamma_init='ones', + beta_init='ones', + moving_mean_init='ones', + moving_var_init='ones', + use_batch_statistics=True, + process_groups=None) + def construct(self, input_data): + return self.sync_bn_op(input_data) + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = C.GradOperation(get_all=True, sens_param=True) + self.network = network + + def construct(self, input_data, sens): + gout = self.grad(self.network)(input_data, sens) + return gout + +def test_sync_batch_norm_forward_fp32_graph(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + init() + x = x_fwd_input.copy().astype(np.float32) + expect_output = expect_output_fwd.copy().astype(np.float32) + overall_shape = x.shape + error = np.ones(shape=overall_shape) * 1.0e-4 + net = Net(2) + net.set_train() + output = net(Tensor(x)) + diff = output.asnumpy() - expect_output + assert np.all(diff < error) + assert np.all(-diff < error) + +def test_sync_batch_norm_forward_fp16_pynative(): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + init() + x = x_fwd_input.copy().astype(np.float16) + expect_output = expect_output_fwd.copy().astype(np.float16) + overall_shape = x.shape + error = np.ones(shape=overall_shape) * 1.0e-3 + net = Net(2) + net.set_train() + output = net(Tensor(x)) + diff = output.asnumpy() - expect_output + assert np.all(diff < error) + assert np.all(-diff < error) + +def test_sync_batch_norm_backwards_fp32_graph(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + init() + x = x_fwd_input.copy().astype(np.float32) + expect_output = expect_output_back.copy().astype(np.float32) + grad = grad_back.copy().astype(np.float32) + overall_shape = x.shape + error = 
np.ones(shape=overall_shape) * 1.0e-5 + fwd_net = Net(2) + fwd_net.set_train() + bn_grad = Grad(fwd_net) + output = bn_grad(Tensor(x), Tensor(grad)) + diff = output[0].asnumpy() - expect_output + assert np.all(diff < error) + assert np.all(-diff < error) + +def test_sync_batch_norm_backwards_fp16_pynative(): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + init() + x = x_fwd_input.copy().astype(np.float16) + expect_output = expect_output_back.copy().astype(np.float16) + grad = grad_back.copy().astype(np.float16) + overall_shape = x.shape + error = np.ones(shape=overall_shape) * 1.0e-3 + fwd_net = Net(2) + fwd_net.set_train() + bn_grad = Grad(fwd_net) + output = bn_grad(Tensor(x), Tensor(grad)) + diff = output[0].asnumpy() - expect_output + assert np.all(diff < error) + assert np.all(-diff < error) diff --git a/tests/st/nccl/test_nccl_sync_batch_norm_op_all.py b/tests/st/nccl/test_nccl_sync_batch_norm_op_all.py new file mode 100644 index 00000000000..d52d7b55b28 --- /dev/null +++ b/tests/st/nccl/test_nccl_sync_batch_norm_op_all.py @@ -0,0 +1,50 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + + +import os +import pytest + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_single +def test_nccl_sync_batch_norm_1(): + cmd_str = "mpirun -n 4 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_forward_fp32_graph" + return_code = os.system(cmd_str) + assert return_code == 0 + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_single +def test_nccl_sync_batch_norm_2(): + cmd_str = "mpirun -n 4 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_forward_fp16_pynative" + return_code = os.system(cmd_str) + assert return_code == 0 + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_single +def test_nccl_sync_batch_norm_3(): + cmd_str = "mpirun -n 1 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_backwards_fp32_graph" + return_code = os.system(cmd_str) + assert return_code == 0 + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_single +def test_nccl_sync_batch_norm_4(): + cmd_str = "mpirun -n 1 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_backwards_fp16_pynative" + return_code = os.system(cmd_str) + assert return_code == 0
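
For reference, the new GPU SyncBatchNorm path can be exercised end-to-end much like the single-op tests above. The following is a minimal sketch, not part of the diff itself: it assumes a multi-GPU environment launched with mpirun (the script name sync_bn_example.py is hypothetical), and relies on the GPU branch added in normalization.py falling back to the default "nccl_world_group" group. Shapes, eps, and momentum are illustrative defaults rather than values taken from the tests.

# Minimal usage sketch (assumption: run under MPI, e.g. `mpirun -n 4 python sync_bn_example.py`)
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.communication.management import init

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
init()  # initializes NCCL; on GPU the layer uses the default "nccl_world_group"

# num_features must match the channel dimension C of the NCHW input
sync_bn = nn.SyncBatchNorm(num_features=2, eps=1e-5, momentum=0.9)
sync_bn.set_train()  # training mode so batch statistics are synchronized across ranks

x = Tensor(np.random.randn(1, 2, 4, 4).astype(np.float32))
out = sync_bn(x)
print(out.shape)  # (1, 2, 4, 4)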