forked from mindspore-Ecosystem/mindspore
first commit
updated files; lint fix; lint fix 2; file name changes; CI run issue fix
This commit is contained in: parent 5c90bae35c, commit c34e52c3d6
@@ -0,0 +1,199 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_grad_impl.cuh"
|
||||
|
||||
const int kWarpSize = 32;
|
||||
const int kNumWarps = 32;
|
||||
|
||||
__inline__ __device__ float HalfFloatInputConvert(const half val) { return __half2float(val); }
|
||||
__inline__ __device__ float HalfFloatInputConvert(const float val) { return val; }
|
||||
__inline__ __device__ void HalfFloatOutputAssign(const float val, float *arr, int idx) { arr[idx] = val; }
|
||||
__inline__ __device__ void HalfFloatOutputAssign(const float val, half *arr, int idx) { arr[idx] = __float2half(val); }
|
||||
|
||||
template <typename T, typename G>
|
||||
__global__ void SyncBatchNormGradPre(size_t N, size_t C, size_t H, size_t W, const T *x_input, const T *dy,
|
||||
G *saved_mean, G *saved_invstd, float *dy_sum_local, float *dot_p_local) {
|
||||
// block level memory
|
||||
__shared__ float shared_dy[kNumWarps];
|
||||
__shared__ float shared_dot_p[kNumWarps];
|
||||
int warpId = threadIdx.x / kWarpSize; // threads are arranged in warps of 32 executed together
|
||||
int laneId = threadIdx.x % kWarpSize;
|
||||
|
||||
int plane = blockIdx.x; // this thread will only function on a single plane
|
||||
int plane_size = N * H * W;
|
||||
float mean = static_cast<float>(saved_mean[plane]);
|
||||
|
||||
if (threadIdx.x < kNumWarps) {
|
||||
shared_dy[threadIdx.x] = static_cast<float>(0);
|
||||
shared_dot_p[threadIdx.x] = static_cast<float>(0);
|
||||
}
|
||||
|
||||
__syncthreads(); // ensure all 0 init complete across all values
|
||||
|
||||
float dy_sum = static_cast<float>(0);
|
||||
float dot_p = static_cast<float>(0);
|
||||
|
||||
// individual thread level reduction
|
||||
for (int x = threadIdx.x; x < plane_size; x += blockDim.x) {
|
||||
int index = (x / (H * W) * C * H * W) + (plane * H * W) + (x % (H * W));
|
||||
float input_value = HalfFloatInputConvert(x_input[index]);
|
||||
float dy_value = HalfFloatInputConvert(dy[index]);
|
||||
dy_sum += dy_value;
|
||||
dot_p += (input_value - mean) * dy_value;
|
||||
}
|
||||
__syncthreads();
|
||||
// warp-level reduction: combine each thread's partial dy_sum and dot_p into one value per warp
|
||||
for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
|
||||
float other_dy_sum = __shfl_down_sync(0xffffffff, dy_sum, offset);
|
||||
float other_dot_p = __shfl_down_sync(0xffffffff, dot_p, offset);
|
||||
dy_sum += other_dy_sum;
|
||||
dot_p += other_dot_p;
|
||||
}
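// after log2(kWarpSize) = 5 shuffle steps, lane 0 of each warp holds that warp's total dy_sum and dot_p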
|
||||
__syncwarp();
|
||||
if (laneId == 0) {
|
||||
shared_dy[warpId] = dy_sum;
|
||||
shared_dot_p[warpId] = dot_p;
|
||||
// one value per warp now
|
||||
}
|
||||
__syncthreads();
|
||||
if (warpId == 0) {
|
||||
dy_sum = shared_dy[laneId];
|
||||
dot_p = shared_dot_p[laneId];
|
||||
__syncwarp();
|
||||
for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
|
||||
float other_dy = __shfl_down_sync(0xffffffff, dy_sum, offset);
|
||||
float other_dot_p = __shfl_down_sync(0xffffffff, dot_p, offset);
|
||||
dy_sum += other_dy;
|
||||
dot_p += other_dot_p;
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
dy_sum_local[plane] = dy_sum;
|
||||
dot_p_local[plane] = dot_p;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, typename S, typename G>
|
||||
__global__ void SyncBatchNormGradPost(size_t N, size_t C, size_t H, size_t W, const T *x_input, const T *dy, T *dx,
|
||||
G *saved_mean, G *saved_invstd, float *dy_sum_red, float *dot_p_red, S *scale,
|
||||
S *dscale, S *dbias, float epsilon) {
|
||||
int size = N * C * H * W;
|
||||
int plane_size = N * H * W;
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
int block_num = (pos / W) / H; // which of N * C blocks
|
||||
int plane = block_num % C;
|
||||
float mean = HalfFloatInputConvert(saved_mean[plane]);
|
||||
float invstd = HalfFloatInputConvert(saved_invstd[plane]);
|
||||
float scale_value = HalfFloatInputConvert(scale[plane]);
|
||||
float div_factor = HalfFloatInputConvert(1) / plane_size;
|
||||
float dy_sum_plane = dy_sum_red[plane];
|
||||
float dot_p_plane = dot_p_red[plane];
|
||||
float grad_mean = dy_sum_plane * div_factor;
|
||||
float proj_scale = dot_p_plane * div_factor * invstd * invstd;
|
||||
float grad_scale = invstd * scale_value;
|
||||
float inp = HalfFloatInputConvert(x_input[pos]);
|
||||
float proj = (inp - mean) * proj_scale;
|
||||
HalfFloatOutputAssign((HalfFloatInputConvert(dy[pos]) - proj - grad_mean) * grad_scale, dx, pos);
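// standard batch-norm input gradient: dx = scale * invstd * (dy - mean(dy) - (x - mean) * invstd^2 * mean(dy * (x - mean)))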
|
||||
}
|
||||
}
|
||||
|
||||
template <typename S, typename G>
|
||||
__global__ void SyncBatchNormGradPostScaleBias(size_t C, G *saved_invstd, float *dy_sum_red, float *dot_p_red,
|
||||
S *dscale, S *dbias) {
|
||||
for (size_t plane = blockIdx.x * blockDim.x + threadIdx.x; plane < C; plane += blockDim.x * gridDim.x) {
|
||||
float invstd = HalfFloatInputConvert(saved_invstd[plane]);
|
||||
float dy_sum_plane = dy_sum_red[plane];
|
||||
float dot_p_plane = dot_p_red[plane];
|
||||
dscale[plane] = static_cast<S>(dot_p_plane * invstd);
|
||||
dbias[plane] = static_cast<S>(dy_sum_plane);
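// parameter gradients: dscale = sum(dy * (x - mean)) * invstd, dbias = sum(dy)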
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename G>
|
||||
void CalSyncBatchNormGradPre(size_t N, size_t C, size_t H, size_t W, const T *x_input, const T *dy, G *saved_mean,
|
||||
G *saved_invstd, float *dy_sum_local, float *dot_p_local, cudaStream_t cuda_stream) {
|
||||
SyncBatchNormGradPre<<<C, GET_THREADS, 0, cuda_stream>>>(N, C, H, W, x_input, dy, saved_mean, saved_invstd,
|
||||
dy_sum_local, dot_p_local);
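// grid dimension is C: one thread block per channel, each block reducing its channel's N * H * W elements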
|
||||
return;
|
||||
}
|
||||
template <typename T, typename S, typename G>
|
||||
void CalSyncBatchNormGradPost(size_t N, size_t C, size_t H, size_t W, const T *x_input, const T *dy, T *dx,
|
||||
G *saved_mean, G *saved_invstd, float *dy_sum_red, float *dot_p_red, S *scale, S *dscale,
|
||||
S *dbias, float epsilon, cudaStream_t cuda_stream) {
|
||||
SyncBatchNormGradPost<<<C, GET_THREADS, 0, cuda_stream>>>(N, C, H, W, x_input, dy, dx, saved_mean, saved_invstd,
|
||||
dy_sum_red, dot_p_red, scale, dscale, dbias, epsilon);
|
||||
SyncBatchNormGradPostScaleBias<<<GET_BLOCKS(C), std::min(C, static_cast<size_t>(GET_THREADS)), 0, cuda_stream>>>(
|
||||
C, saved_invstd, dy_sum_red, dot_p_red, dscale, dbias);
|
||||
}
|
||||
// PRE FUNCTION
|
||||
template void CalSyncBatchNormGradPre<float, float>(size_t N, size_t C, size_t H, size_t W, const float *x_input,
|
||||
const float *dy, float *saved_mean, float *saved_invstd,
|
||||
float *dy_sum_local, float *dot_p_local, cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormGradPre<float, half>(size_t N, size_t C, size_t H, size_t W, const float *x_input,
|
||||
const float *dy, half *saved_mean, half *saved_invstd,
|
||||
float *dy_sum_local, float *dot_p_local, cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormGradPre<half, float>(size_t N, size_t C, size_t H, size_t W, const half *x_input,
|
||||
const half *dy, float *saved_mean, float *saved_invstd,
|
||||
float *dy_sum_local, float *dot_p_local, cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormGradPre<half, half>(size_t N, size_t C, size_t H, size_t W, const half *x_input,
|
||||
const half *dy, half *saved_mean, half *saved_invstd,
|
||||
float *dy_sum_local, float *dot_p_local, cudaStream_t cuda_stream);
|
||||
// POST FUNCTION
|
||||
template void CalSyncBatchNormGradPost<float, float, float>(size_t N, size_t C, size_t H, size_t W,
|
||||
const float *x_input, const float *dy, float *dx,
|
||||
float *saved_mean, float *saved_invstd, float *dy_sum_red,
|
||||
float *dot_p_red, float *scale, float *dscale, float *dbias,
|
||||
float epsilon, cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormGradPost<half, float, float>(size_t N, size_t C, size_t H, size_t W, const half *x_input,
|
||||
const half *dy, half *dx, float *saved_mean,
|
||||
float *saved_invstd, float *dy_sum_red, float *dot_p_red,
|
||||
float *scale, float *dscale, float *dbias, float epsilon,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormGradPost<float, half, float>(size_t N, size_t C, size_t H, size_t W, const float *x_input,
|
||||
const float *dy, float *dx, float *saved_mean,
|
||||
float *saved_invstd, float *dy_sum_red, float *dot_p_red,
|
||||
half *scale, half *dscale, half *dbias, float epsilon,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormGradPost<half, half, float>(size_t N, size_t C, size_t H, size_t W, const half *x_input,
|
||||
const half *dy, half *dx, float *saved_mean,
|
||||
float *saved_invstd, float *dy_sum_red, float *dot_p_red,
|
||||
half *scale, half *dscale, half *dbias, float epsilon,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormGradPost<float, float, half>(size_t N, size_t C, size_t H, size_t W, const float *x_input,
|
||||
const float *dy, float *dx, half *saved_mean,
|
||||
half *saved_invstd, float *dy_sum_red, float *dot_p_red,
|
||||
float *scale, float *dscale, float *dbias, float epsilon,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormGradPost<half, float, half>(size_t N, size_t C, size_t H, size_t W, const half *x_input,
|
||||
const half *dy, half *dx, half *saved_mean,
|
||||
half *saved_invstd, float *dy_sum_red, float *dot_p_red,
|
||||
float *scale, float *dscale, float *dbias, float epsilon,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormGradPost<float, half, half>(size_t N, size_t C, size_t H, size_t W, const float *x_input,
|
||||
const float *dy, float *dx, half *saved_mean,
|
||||
half *saved_invstd, float *dy_sum_red, float *dot_p_red,
|
||||
half *scale, half *dscale, half *dbias, float epsilon,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormGradPost<half, half, half>(size_t N, size_t C, size_t H, size_t W, const half *x_input,
|
||||
const half *dy, half *dx, half *saved_mean, half *saved_invstd,
|
||||
float *dy_sum_red, float *dot_p_red, half *scale, half *dscale,
|
||||
half *dbias, float epsilon, cudaStream_t cuda_stream);
|
|
@@ -0,0 +1,27 @@
|
|||
// /**
|
||||
// * Copyright 2021 Huawei Technologies Co., Ltd
|
||||
// *
|
||||
// * Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// * you may not use this file except in compliance with the License.
|
||||
// * You may obtain a copy of the License at
|
||||
// *
|
||||
// * http://www.apache.org/licenses/LICENSE-2.0
|
||||
// *
|
||||
// * Unless required by applicable law or agreed to in writing, software
|
||||
// * distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// * See the License for the specific language governing permissions and
|
||||
// * limitations under the License.
|
||||
// */
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SYNC_BATCH_NORM_GRAD_IMPL_CUH
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SYNC_BATCH_NORM_GRAD_IMPL_CUH
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T, typename G>
|
||||
void CalSyncBatchNormGradPre(size_t N, size_t C, size_t H, size_t W, const T *x_input, const T *dy, G *saved_mean,
|
||||
G *invstd_saved, float *dy_sum_local, float *dot_p_local, cudaStream_t cuda_stream);
|
||||
template <typename T, typename S, typename G>
|
||||
void CalSyncBatchNormGradPost(size_t N, size_t C, size_t H, size_t W, const T *x_input, const T *dy, T *dx,
|
||||
G *saved_mean, G *invstd_saved, float *dy_sum_red, float *dot_p_red, S *scale, S *dscale,
|
||||
S *dbias, float epsilon, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SYNC_BATCH_NORM_GRAD_IMPL_CUH
|
|
@@ -0,0 +1,248 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_impl.cuh"
|
||||
|
||||
const int kWarpSize = 32;
|
||||
const int kNumWarps = 32;
|
||||
|
||||
__inline__ __device__ float HalfFloatInputConvert(const half val) { return __half2float(val); }
|
||||
__inline__ __device__ float HalfFloatInputConvert(const float val) { return val; }
|
||||
__inline__ __device__ void HalfFloatOutputAssign(const float val, float *arr, int idx) { arr[idx] = val; }
|
||||
__inline__ __device__ void HalfFloatOutputAssign(const float val, half *arr, int idx) { arr[idx] = __float2half(val); }
|
||||
|
||||
template <typename T>
|
||||
__global__ void SyncBatchNormPre(size_t N, size_t C, size_t H, size_t W, const T *input, int *output_n,
|
||||
float *output_mean, float *output_invstd, float epsilon) {
|
||||
// block level memory
|
||||
__shared__ float shared_mean[kNumWarps];
|
||||
__shared__ float shared_var[kNumWarps];
|
||||
__shared__ int shared_n[kNumWarps];
|
||||
|
||||
int warpId = threadIdx.x / kWarpSize; // threads execute in warps of 32
|
||||
int laneId = threadIdx.x % kWarpSize;
|
||||
int plane = blockIdx.x;
|
||||
int plane_size = N * H * W;
|
||||
if (threadIdx.x < kNumWarps) {
|
||||
shared_mean[threadIdx.x] = static_cast<float>(0);
|
||||
shared_var[threadIdx.x] = static_cast<float>(0);
|
||||
}
|
||||
// ensure all 0 init complete across all values
|
||||
__syncthreads();
|
||||
|
||||
// agg values
|
||||
float avg = 0;
|
||||
float var_n = 0;
|
||||
int n = 0;
|
||||
|
||||
// individual thread level reduction
|
||||
for (int x = threadIdx.x; x < plane_size; x += blockDim.x) {
|
||||
int index = (x / (H * W) * C * H * W) + (plane * H * W) + (x % (H * W));
|
||||
float input_val = HalfFloatInputConvert(input[index]);
|
||||
float d1 = input_val - avg;
|
||||
n++;
|
||||
avg = avg + (d1 / n);
|
||||
var_n = var_n + (d1 * (input_val - avg));
|
||||
}
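// Welford's online update: avg is the running mean and var_n the running sum of squared deviations (M2) over this thread's samples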
|
||||
__syncthreads();
|
||||
|
||||
// Reduce every warp to a single value
|
||||
for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
|
||||
float other_avg = __shfl_down_sync(0xffffffff, avg, offset);
|
||||
float other_n = __shfl_down_sync(0xffffffff, n, offset);
|
||||
float div_factor = 1.0 / fmaxf(1.0, n + other_n);
|
||||
float other_var_n = __shfl_down_sync(0xffffffff, var_n, offset);
|
||||
var_n += other_var_n + (avg - other_avg) * (avg - other_avg) * n * other_n * div_factor;
|
||||
avg = (n * avg + other_n * other_avg) * div_factor;
|
||||
n += other_n;
|
||||
}
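// pairwise merge of two partial (n, mean, M2) triples via warp shuffles (Chan et al. parallel variance formula)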
|
||||
__syncwarp();
|
||||
if (laneId == 0) {
|
||||
// lane 0 of each warp writes its warp's partial result to shared memory
|
||||
shared_n[warpId] = n;
|
||||
shared_mean[warpId] = avg;
|
||||
shared_var[warpId] = var_n;
|
||||
// now one value per warp
|
||||
}
|
||||
// second reduction to reduce all warps into a single value
|
||||
__syncthreads();
|
||||
if (warpId == 0) {
|
||||
n = shared_n[laneId];
|
||||
avg = shared_mean[laneId];
|
||||
var_n = shared_var[laneId];
|
||||
__syncwarp();
|
||||
for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
|
||||
int other_n = __shfl_down_sync(0xffffffff, n, offset);
|
||||
float other_avg = __shfl_down_sync(0xffffffff, avg, offset);
|
||||
float div_factor = 1.0 / fmaxf(1.0, n + other_n);
|
||||
float other_var_n = __shfl_down_sync(0xffffffff, var_n, offset);
|
||||
var_n += other_var_n + (avg - other_avg) * (avg - other_avg) * n * other_n * div_factor;
|
||||
avg = (n * avg + other_n * other_avg) * div_factor;
|
||||
n += other_n;
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
output_n[plane] = n;
|
||||
output_mean[plane] = avg;
|
||||
output_invstd[plane] = static_cast<float>(1) / sqrt((var_n / plane_size) + epsilon);
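// var_n / plane_size is the biased (population) variance of this plane; invstd = 1 / sqrt(var + epsilon)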
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, typename G>
|
||||
__global__ void SyncBatchNormGather(size_t N, size_t C, size_t H, size_t W, int *counts_global, float *means_global,
|
||||
float *invstds_global, int *counts_local, float *means_local, float *invstds_local,
|
||||
T *running_mean_output, T *running_var_output, G *running_mean_input,
|
||||
G *running_var_input, float epsilon, float momentum, size_t group_rank,
|
||||
size_t group_size) {
|
||||
int feature_size = C;
|
||||
int world_size = group_size;
|
||||
for (size_t C_ix = blockIdx.x * blockDim.x + threadIdx.x; C_ix < C; C_ix += blockDim.x * gridDim.x) {
|
||||
float avg = 0;
|
||||
float var_n = 0;
|
||||
float n = 0;
|
||||
for (int N_ix = 0; N_ix < world_size; N_ix++) {
|
||||
int count = counts_global[N_ix * feature_size + C_ix];
|
||||
float mean_ = means_global[N_ix * feature_size + C_ix];
|
||||
float std = static_cast<float>(1) / invstds_global[N_ix * feature_size + C_ix];
|
||||
float var_n_ = (std * std - epsilon) * count;
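// recover each device's sum of squared deviations from its invstd: var = 1 / invstd^2 - epsilon, M2 = var * count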
|
||||
float div_factor = 1.0 / fmaxf(1.0, n + count);
|
||||
var_n += var_n_ + (avg - mean_) * (avg - mean_) * n * count * div_factor;
|
||||
avg = n * div_factor * avg + count * div_factor * mean_;
|
||||
n += count;
|
||||
}
|
||||
means_local[C_ix] = avg;
|
||||
invstds_local[C_ix] = static_cast<float>(1) / sqrt((var_n / n) + epsilon);
|
||||
HalfFloatOutputAssign(((1 - momentum) * HalfFloatInputConvert(running_mean_input[C_ix]) + momentum * avg),
|
||||
running_mean_output, C_ix);
|
||||
float unbiasedVar = 0.0;
|
||||
if (n != 0) { // not strictly required since pipeline does not allow empty inputs
|
||||
unbiasedVar = var_n / n;
|
||||
}
|
||||
HalfFloatOutputAssign(((1 - momentum) * HalfFloatInputConvert(running_var_input[C_ix]) + momentum * unbiasedVar),
|
||||
running_var_output, C_ix);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
__global__ void SyncBatchNormPost(size_t N, size_t C, size_t H, size_t W, const T *input, T *output, float *means_local,
|
||||
float *invstds_local, S *scale, S *bias, float epsilon) {
|
||||
int size = N * C * H * W;
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
int block_num = (pos / W) / H; // which of N * C blocks
|
||||
int plane = block_num % C;
|
||||
float scale_plane = HalfFloatInputConvert(scale[plane]);
|
||||
float bias_plane = HalfFloatInputConvert(bias[plane]);
|
||||
float mean_plane = means_local[plane];
|
||||
float invstd_plane = invstds_local[plane];
|
||||
float input_val = HalfFloatInputConvert(input[pos]);
|
||||
HalfFloatOutputAssign(scale_plane * (input_val - mean_plane) * invstd_plane + bias_plane, output, pos);
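// y = scale * (x - mean) * invstd + bias, using the group-wide statistics produced by the gather step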
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
__global__ void SyncBatchNormPostBiasScale(size_t C, S *scale, S *bias, S *output_scale, S *output_bias) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < C; pos += blockDim.x * gridDim.x) {
|
||||
output_bias[pos] = bias[pos];
|
||||
output_scale[pos] = scale[pos];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalSyncBatchNormPre(size_t N, size_t C, size_t H, size_t W, const T *input, int *output_n, float *output_mean,
|
||||
float *output_var, float epsilon, cudaStream_t cuda_stream) {
|
||||
SyncBatchNormPre<<<C, GET_THREADS, 0, cuda_stream>>>(N, C, H, W, input, output_n, output_mean, output_var, epsilon);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, typename G>
|
||||
void CalSyncBatchNormGather(size_t N, size_t C, size_t H, size_t W, int *counts_global, float *means_global,
|
||||
float *invstds_global, int *counts_local, float *means_local, float *invstds_local,
|
||||
T *running_mean_output, T *running_var_output, G *running_mean_input, G *running_var_input,
|
||||
float epsilon, float momentum, size_t group_rank, size_t group_size,
|
||||
cudaStream_t cuda_stream) {
|
||||
SyncBatchNormGather<<<GET_BLOCKS(C), GET_THREADS, 0, cuda_stream>>>(
|
||||
N, C, H, W, counts_global, means_global, invstds_global, counts_local, means_local, invstds_local,
|
||||
running_mean_output, running_var_output, running_mean_input, running_var_input, epsilon, momentum, group_rank,
|
||||
group_size);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void CalSyncBatchNormPost(size_t N, size_t C, size_t H, size_t W, const T *input, T *output, float *means_local,
|
||||
float *invstds_local, S *scale, S *bias, S *output_scale, S *output_bias, float epsilon,
|
||||
cudaStream_t cuda_stream) {
|
||||
SyncBatchNormPost<<<GET_BLOCKS(N * C * H * W), GET_THREADS, 0, cuda_stream>>>(N, C, H, W, input, output, means_local,
|
||||
invstds_local, scale, bias, epsilon);
|
||||
SyncBatchNormPostBiasScale<<<1, std::min(C, static_cast<size_t>(GET_THREADS)), 0, cuda_stream>>>(
|
||||
C, scale, bias, output_scale, output_bias);
|
||||
return;
|
||||
}
|
||||
|
||||
template void CalSyncBatchNormPre<float>(size_t N, size_t C, size_t H, size_t W, const float *input, int *output_n,
|
||||
float *output_mean, float *output_var, float epsilon,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormPre<half>(size_t N, size_t C, size_t H, size_t W, const half *input, int *output_n,
|
||||
float *output_mean, float *output_var, float epsilon, cudaStream_t cuda_stream);
|
||||
|
||||
template void CalSyncBatchNormGather<float, float>(size_t N_, size_t C_, size_t H_, size_t W_, int *counts_global,
|
||||
float *means_global, float *invstds_global, int *counts_local,
|
||||
float *means_local, float *invstds_local, float *running_mean_output,
|
||||
float *running_var_output, float *running_mean_input,
|
||||
float *running_var_input, float epsilon, float momentum,
|
||||
size_t group_rank, size_t group_size, cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormGather<float, half>(size_t N_, size_t C_, size_t H_, size_t W_, int *counts_global,
|
||||
float *means_global, float *invstds_global, int *counts_local,
|
||||
float *means_local, float *invstds_local, float *running_mean_output,
|
||||
float *running_var_output, half *running_mean_input,
|
||||
half *running_var_input, float epsilon, float momentum,
|
||||
size_t group_rank, size_t group_size, cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormGather<half, float>(size_t N_, size_t C_, size_t H_, size_t W_, int *counts_global,
|
||||
float *means_global, float *invstds_global, int *counts_local,
|
||||
float *means_local, float *invstds_local, half *running_mean_output,
|
||||
half *running_var_output, float *running_mean_input,
|
||||
float *running_var_input, float epsilon, float momentum,
|
||||
size_t group_rank, size_t group_size, cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormGather<half, half>(size_t N_, size_t C_, size_t H_, size_t W_, int *counts_global,
|
||||
float *means_global, float *invstds_global, int *counts_local,
|
||||
float *means_local, float *invstds_local, half *running_mean_output,
|
||||
half *running_var_output, half *running_mean_input,
|
||||
half *running_var_input, float epsilon, float momentum,
|
||||
size_t group_rank, size_t group_size, cudaStream_t cuda_stream);
|
||||
|
||||
template void CalSyncBatchNormPost<float, float>(size_t N, size_t C, size_t H, size_t W, const float *input,
|
||||
float *output, float *means_local, float *invstds_local, float *scale,
|
||||
float *bias, float *output_scale, float *output_bias, float epsilon,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormPost<float, half>(size_t N, size_t C, size_t H, size_t W, const float *input,
|
||||
float *output, float *means_local, float *invstds_local, half *scale,
|
||||
half *bias, half *output_scale, half *output_bias, float epsilon,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormPost<half, float>(size_t N, size_t C, size_t H, size_t W, const half *input, half *output,
|
||||
float *means_local, float *invstds_local, float *scale, float *bias,
|
||||
float *output_scale, float *output_bias, float epsilon,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSyncBatchNormPost<half, half>(size_t N, size_t C, size_t H, size_t W, const half *input, half *output,
|
||||
float *means_local, float *invstds_local, half *scale, half *bias,
|
||||
half *output_scale, half *output_bias, float epsilon,
|
||||
cudaStream_t cuda_stream);
|
|
@@ -0,0 +1,33 @@
|
|||
// /**
|
||||
// * Copyright 2021 Huawei Technologies Co., Ltd
|
||||
// *
|
||||
// * Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// * you may not use this file except in compliance with the License.
|
||||
// * You may obtain a copy of the License at
|
||||
// *
|
||||
// * http://www.apache.org/licenses/LICENSE-2.0
|
||||
// *
|
||||
// * Unless required by applicable law or agreed to in writing, software
|
||||
// * distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// * See the License for the specific language governing permissions and
|
||||
// * limitations under the License.
|
||||
// */
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SYNC_BATCH_NORM_IMPL_CUH
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SYNC_BATCH_NORM_IMPL_CUH
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
void CalSyncBatchNormPre(size_t N, size_t C, size_t H, size_t W, const T *input, int *output_n, float *means_local,
|
||||
float *invstds_local, float epsilon, cudaStream_t cuda_stream);
|
||||
template <typename T, typename G>
|
||||
void CalSyncBatchNormGather(size_t N, size_t C, size_t H, size_t W, int *counts_global, float *means_global,
|
||||
float *invstds_global, int *counts_local, float *means_local, float *invstds_local,
|
||||
T *running_mean_output, T *running_var_output, G *running_mean_input, G *running_var_input,
|
||||
float epsilon, float momentum, size_t group_rank, size_t group_size,
|
||||
cudaStream_t cuda_stream);
|
||||
template <typename T, typename S>
|
||||
void CalSyncBatchNormPost(size_t N, size_t C, size_t H, size_t W, const T *input, T *output, float *means_local,
|
||||
float *invstds_local, S *scale, S *bias, S *output_scale, S *output_bias, float epsilon,
|
||||
cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SYNC_BATCH_NORM_IMPL_CUH
|
|
@@ -0,0 +1,126 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nccl/sync_batch_norm_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
SyncBatchNormGpuKernel, float, float, float)
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
SyncBatchNormGpuKernel, half, float, float)
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
SyncBatchNormGpuKernel, float, half, float)
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
SyncBatchNormGpuKernel, half, half, float)
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
SyncBatchNormGpuKernel, float, float, half)
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
SyncBatchNormGpuKernel, half, float, half)
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
SyncBatchNormGpuKernel, float, half, half)
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
SyncBatchNormGpuKernel, half, half, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@@ -0,0 +1,246 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GPU_KERNEL_H_
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <stdint.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S, typename G>
|
||||
class SyncBatchNormGpuKernel : public NcclGpuKernel {
|
||||
public:
|
||||
SyncBatchNormGpuKernel() { ResetResource(); }
|
||||
~SyncBatchNormGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *x = GetDeviceAddress<T>(inputs, 0);
|
||||
S *scale = GetDeviceAddress<S>(inputs, 1);
|
||||
S *bias = GetDeviceAddress<S>(inputs, 2);
|
||||
G *running_mean_input = GetDeviceAddress<G>(inputs, 3);
|
||||
G *running_variance_input = GetDeviceAddress<G>(inputs, 4);
|
||||
|
||||
float *means_local = GetDeviceAddress<float>(workspace, 0); // per device
|
||||
float *invstds_local = GetDeviceAddress<float>(workspace, 1);
|
||||
int *counts_local = GetDeviceAddress<int>(workspace, 2);
|
||||
int *counts_global = GetDeviceAddress<int>(workspace, 3); // gathered values from all devices
|
||||
float *means_global = GetDeviceAddress<float>(workspace, 4);
|
||||
float *invstds_global = GetDeviceAddress<float>(workspace, 5);
|
||||
|
||||
T *y = GetDeviceAddress<T>(outputs, 0);
|
||||
S *output_scale = GetDeviceAddress<S>(outputs, 1);
|
||||
S *output_bias = GetDeviceAddress<S>(outputs, 2);
|
||||
T *output_running_mean = GetDeviceAddress<T>(outputs, 3);
|
||||
T *output_running_variance = GetDeviceAddress<T>(outputs, 4);
|
||||
|
||||
// aggregate means and invstd on each device locally
|
||||
CalSyncBatchNormPre(N_, C_, H_, W_, x, counts_local, means_local, invstds_local, epsilon_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
// gather values from all devices together
|
||||
LaunchAllGather(means_local, means_global, stream_ptr);
|
||||
LaunchAllGather(invstds_local, invstds_global, stream_ptr);
|
||||
LaunchAllGather(counts_local, counts_global, stream_ptr);
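// after the three AllGathers every device holds each rank's per-channel count/mean/invstd, laid out rank-major as [group_size, C]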
|
||||
// reduce the gathered values on each device and update the running mean and variance
|
||||
CalSyncBatchNormGather(N_, C_, H_, W_, counts_global, means_global, invstds_global, counts_local, means_local,
|
||||
invstds_local, output_running_mean, output_running_variance, running_mean_input,
|
||||
running_variance_input, epsilon_, momentum_, group_rank_, group_size_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CalSyncBatchNormPost(N_, C_, H_, W_, x, y, means_local, invstds_local, scale, bias, output_scale, output_bias,
|
||||
epsilon_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
auto root_rank = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kAttrRootRank);
|
||||
if (root_rank) {
|
||||
root_ = static_cast<int>(GetValue<int64_t>(root_rank));
|
||||
}
|
||||
nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0));
|
||||
group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 5) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but SyncBatchNorm needs 5 inputs.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 5) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but SyncBatchNorm needs 5 output.";
|
||||
return false;
|
||||
}
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (CHECK_NULL_INPUT(input_shape)) {
|
||||
MS_LOG(WARNING) << "SyncBatchNorm input is null";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
auto input_shape_dims = input_shape.size();
|
||||
if (input_shape_dims != 4 && input_shape_dims != 2) {
|
||||
MS_LOG(EXCEPTION) << "Tensor shape is " << input_shape.size()
|
||||
<< ", SyncBatchNormGpuKernel input should be 2D or 4D";
|
||||
}
|
||||
input_size_ = 1;
|
||||
for (auto dim : input_shape) {
|
||||
input_size_ *= dim;
|
||||
}
|
||||
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
|
||||
momentum_ = GetAttr<float>(kernel_node, "momentum");
|
||||
output_size_ = input_size_;
|
||||
output_size_ = output_size_ * sizeof(T);
|
||||
input_size_ = input_size_ * sizeof(T);
|
||||
param_count_ = input_shape[1]; // C is number of features
|
||||
param_size_S_ = param_count_ * sizeof(S);  // S: scale/bias type (second template parameter)
|
||||
param_size_G_input_ = param_count_ * sizeof(G);
|
||||
param_size_G_output_ = param_count_ * sizeof(T);
|
||||
workspace_size_ = param_count_; // specific size computed in InitSizeLists()
|
||||
N_ = input_shape[0];
|
||||
C_ = input_shape[1];
|
||||
if (input_shape_dims == 2) {
|
||||
// NC -> N,C,1,1 transform input dims
|
||||
H_ = 1;
|
||||
W_ = 1;
|
||||
} else {
|
||||
H_ = input_shape[2];
|
||||
W_ = input_shape[3];
|
||||
}
|
||||
// MULTI DEVICE SPECIFICS
|
||||
group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup);
|
||||
MS_LOG(INFO) << AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_;
|
||||
auto comm_stream_attr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stream_id");
|
||||
if (comm_stream_attr) {
|
||||
comm_stream_ = reinterpret_cast<cudaStream_t>(GetValue<uintptr_t>(comm_stream_attr));
|
||||
MS_EXCEPTION_IF_NULL(comm_stream_);
|
||||
}
|
||||
collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
|
||||
MS_EXCEPTION_IF_NULL(collective_handle_);
|
||||
// Get group size
|
||||
auto get_group_size_funcptr =
|
||||
reinterpret_cast<GetGroupRanks>(dlsym(const_cast<void *>(collective_handle_), "GetGroupRanks"));
|
||||
MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
|
||||
std::vector<int> group_ranks = (*get_group_size_funcptr)(group_name_);
|
||||
group_size_ = group_ranks.size();
|
||||
// Get device rank ID in group
|
||||
using GetLocalRankId = device::gpu::GetLocalRankId;
|
||||
auto get_local_rank_funcptr =
|
||||
reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
|
||||
MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
|
||||
group_rank_ = IntToUint((*get_local_rank_funcptr)());
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
input_size_ = 0;
|
||||
output_size_ = 0;
|
||||
workspace_size_ = 0;
|
||||
momentum_ = 0;
|
||||
epsilon_ = 10e-5;
|
||||
param_size_S_ = 0;
|
||||
param_size_G_input_ = 0;
|
||||
param_size_G_output_ = 0;
|
||||
param_count_ = 0;
|
||||
N_ = 0;
|
||||
C_ = 0;
|
||||
H_ = 0;
|
||||
W_ = 0;
|
||||
root_ = 0;
|
||||
collective_handle_ = nullptr;
|
||||
comm_stream_ = nullptr;
|
||||
nccl_reduce_type_ = ncclSum;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_); // input x
|
||||
input_size_list_.push_back(param_size_S_); // scale
|
||||
input_size_list_.push_back(param_size_S_); // bias
|
||||
input_size_list_.push_back(param_size_G_input_); // running mean
|
||||
input_size_list_.push_back(param_size_G_input_); // running variance
|
||||
output_size_list_.push_back(output_size_); // output
|
||||
output_size_list_.push_back(param_size_S_); // output scale
output_size_list_.push_back(param_size_S_); // output bias
output_size_list_.push_back(param_size_G_output_); // updated running mean
output_size_list_.push_back(param_size_G_output_); // updated running variance
|
||||
// local mean/variance data - per device
|
||||
workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // mean_local
|
||||
workspace_size_list_.push_back(workspace_size_ * sizeof(float)); // invstd_local
|
||||
workspace_size_list_.push_back(workspace_size_ * sizeof(int)); // count_local
|
||||
// global mean/variance data - for all devices
|
||||
workspace_size_list_.push_back(workspace_size_ * sizeof(int) * group_size_); // gathered counts
workspace_size_list_.push_back(workspace_size_ * sizeof(float) * group_size_); // gathered means
workspace_size_list_.push_back(workspace_size_ * sizeof(float) * group_size_); // gathered invstds
|
||||
}
|
||||
|
||||
private:
|
||||
// GetTypeID functions return the correct typeID for input template
|
||||
// Allow for a single templated LaunchAllGather function
|
||||
mindspore::TypeId GetTypeID(float *input) { return kNumberTypeFloat32; }
|
||||
mindspore::TypeId GetTypeID(int *input) { return kNumberTypeInt32; }
|
||||
template <typename gather_type>
|
||||
void LaunchAllGather(gather_type *input_addr, gather_type *output_addr, void *stream_ptr) {
|
||||
cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||
auto all_gather_funcptr = reinterpret_cast<AllGather>(dlsym(const_cast<void *>(collective_handle_), "AllGather"));
|
||||
MS_EXCEPTION_IF_NULL(all_gather_funcptr);
|
||||
CHECK_NCCL_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
(*all_gather_funcptr)(input_addr, output_addr, C_, nccl_dtype(GetTypeID(input_addr)), stream, group_name_),
|
||||
"ncclAllGather failed");
|
||||
}
|
||||
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
size_t workspace_size_;
|
||||
float momentum_;
|
||||
float epsilon_;
|
||||
size_t param_size_S_;
|
||||
size_t param_size_G_input_;
|
||||
size_t param_size_G_output_;
|
||||
size_t param_count_;
|
||||
size_t N_;
|
||||
size_t C_;
|
||||
size_t H_;
|
||||
size_t W_;
|
||||
size_t group_size_;
|
||||
size_t group_rank_;
|
||||
ncclRedOp_t nccl_reduce_type_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
// NCCL
|
||||
string group_name_;
|
||||
int root_;
|
||||
const void *collective_handle_;
|
||||
cudaStream_t comm_stream_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GPU_KERNEL_H_
|
|
@@ -0,0 +1,110 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nccl/sync_batch_norm_grad_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
SyncBatchNormGradGpuKernel, float, float, float)
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
SyncBatchNormGradGpuKernel, half, float, float)
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
SyncBatchNormGradGpuKernel, float, half, float)
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
SyncBatchNormGradGpuKernel, half, half, float)
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
SyncBatchNormGradGpuKernel, float, float, half)
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
SyncBatchNormGradGpuKernel, half, float, half)
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
SyncBatchNormGradGpuKernel, float, half, half)
|
||||
MS_REG_GPU_KERNEL_THREE(SyncBatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
SyncBatchNormGradGpuKernel, half, half, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@@ -0,0 +1,209 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GRAD_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GRAD_GPU_KERNEL_H_
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <stdint.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/sync_batch_norm_grad_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S, typename G>
|
||||
class SyncBatchNormGradGpuKernel : public NcclGpuKernel {
|
||||
public:
|
||||
SyncBatchNormGradGpuKernel() { ResetResource(); }
|
||||
~SyncBatchNormGradGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *dy = GetDeviceAddress<T>(inputs, 0);
|
||||
T *x_input = GetDeviceAddress<T>(inputs, 1);
|
||||
S *scale = GetDeviceAddress<S>(inputs, 2);
|
||||
G *saved_mean = GetDeviceAddress<G>(inputs, 3);
|
||||
G *saved_variance = GetDeviceAddress<G>(inputs, 4);
|
||||
float *dy_sum_local = GetDeviceAddress<float>(workspace, 0);
|
||||
float *dot_p_local = GetDeviceAddress<float>(workspace, 1);
|
||||
float *dy_sum_red = GetDeviceAddress<float>(workspace, 2);
|
||||
float *dot_p_red = GetDeviceAddress<float>(workspace, 3);
|
||||
T *dx = GetDeviceAddress<T>(outputs, 0);
|
||||
S *dscale = GetDeviceAddress<S>(outputs, 1);
|
||||
S *dbias = GetDeviceAddress<S>(outputs, 2);
|
||||
// aggregate interim values on each device locally
|
||||
CalSyncBatchNormGradPre(N_, C_, H_, W_, x_input, dy, saved_mean, saved_variance, dy_sum_local, dot_p_local,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
// reduce values across devices
|
||||
LaunchAllReduce(dy_sum_local, dy_sum_red, stream_ptr);
|
||||
LaunchAllReduce(dot_p_local, dot_p_red, stream_ptr);
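// AllReduce (sum) turns the per-device partial sums into group-wide totals of dy and dy * (x - mean)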
|
||||
// Aggregate and compute output
|
||||
CalSyncBatchNormGradPost(N_, C_, H_, W_, x_input, dy, dx, saved_mean, saved_variance, dy_sum_red, dot_p_red, scale,
|
||||
dscale, dbias, epsilon_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
auto root_rank = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kAttrRootRank);
|
||||
if (root_rank) {
|
||||
root_ = static_cast<int>(GetValue<int64_t>(root_rank));
|
||||
}
|
||||
nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0));
|
||||
group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 5) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but SyncBatchNormGrad needs 5 inputs.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 3) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but SyncBatchNormGrad needs 5 output.";
|
||||
return false;
|
||||
}
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (CHECK_NULL_INPUT(input_shape)) {
|
||||
MS_LOG(WARNING) << "SyncBatchNormGrad input is null";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
auto input_shape_dims = input_shape.size();
|
||||
if (input_shape_dims != 4 && input_shape_dims != 2) {
|
||||
MS_LOG(EXCEPTION) << "Tensor shape is " << input_shape.size()
|
||||
<< ", SyncBatchNormGpuGrad input should be 2D or 4D";
|
||||
}
|
||||
input_size_ = 1;
|
||||
for (auto dim : input_shape) {
|
||||
input_size_ *= dim;
|
||||
}
|
||||
output_size_ = input_size_;
|
||||
output_size_ = output_size_ * sizeof(T);
|
||||
input_size_ = input_size_ * sizeof(T);
|
||||
param_count_ = input_shape[1];
|
||||
param_size_S_ = param_count_ * sizeof(S);
|
||||
param_size_G_ = param_count_ * sizeof(G);
|
||||
N_ = input_shape[0];
|
||||
C_ = input_shape[1];
|
||||
if (input_shape_dims == 2) { // N,C,1,1 transform input
|
||||
H_ = 1;
|
||||
W_ = 1;
|
||||
} else {
|
||||
H_ = input_shape[2];
|
||||
W_ = input_shape[3];
|
||||
}
|
||||
workspace_size_ = C_;
|
||||
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
|
||||
// MULTIDEVICE SPECIFICS
|
||||
group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup);
|
||||
MS_LOG(INFO) << AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_;
|
||||
auto comm_stream_attr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stream_id");
|
||||
if (comm_stream_attr) {
|
||||
comm_stream_ = reinterpret_cast<cudaStream_t>(GetValue<uintptr_t>(comm_stream_attr));
|
||||
MS_EXCEPTION_IF_NULL(comm_stream_);
|
||||
}
|
||||
collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
|
||||
MS_EXCEPTION_IF_NULL(collective_handle_);
|
||||
// Get group size
|
||||
auto get_group_size_funcptr =
|
||||
reinterpret_cast<GetGroupRanks>(dlsym(const_cast<void *>(collective_handle_), "GetGroupRanks"));
|
||||
MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
|
||||
std::vector<int> group_ranks = (*get_group_size_funcptr)(group_name_);
|
||||
device_count_ = group_ranks.size();
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
input_size_ = 0;
|
||||
output_size_ = 0;
|
||||
workspace_size_ = 0;
|
||||
epsilon_ = 10e-5; // default
|
||||
param_size_S_ = 0;
|
||||
param_size_G_ = 0;
|
||||
param_count_ = 0;
|
||||
N_ = 0;
|
||||
C_ = 0;
|
||||
H_ = 0;
|
||||
W_ = 0;
|
||||
root_ = 0;
|
||||
collective_handle_ = nullptr;
|
||||
comm_stream_ = nullptr;
|
||||
nccl_reduce_type_ = ncclSum;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_);    // dy
    input_size_list_.push_back(input_size_);    // x
    input_size_list_.push_back(param_size_S_);  // scale
    input_size_list_.push_back(param_size_G_);  // saved_mean
    input_size_list_.push_back(param_size_G_);  // saved_variance
    output_size_list_.push_back(output_size_);   // dx
    output_size_list_.push_back(param_size_S_);  // dscale
    output_size_list_.push_back(param_size_S_);  // dbias
    workspace_size_list_.push_back(workspace_size_ * sizeof(float));  // sum_dy
    workspace_size_list_.push_back(workspace_size_ * sizeof(float));  // sum_dy_xmu
    workspace_size_list_.push_back(workspace_size_ * sizeof(float));  // reduced sum_dy
    workspace_size_list_.push_back(workspace_size_ * sizeof(float));  // reduced sum_dy_xmu
  }

 private:
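  // LaunchAllReduce sums a per-channel buffer (C_ elements) across every device in group_name_, using the
  // dedicated communication stream when one is set and the kernel's compute stream otherwise.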
  template <typename reduce_type>
  void LaunchAllReduce(reduce_type *input_addr, reduce_type *output_addr, void *stream_ptr) {
    cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
    auto all_reduce_funcptr = reinterpret_cast<AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce"));
    MS_EXCEPTION_IF_NULL(all_reduce_funcptr);
    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
                               (*all_reduce_funcptr)(input_addr, output_addr, C_, nccl_dtype(kNumberTypeFloat32),
                                                     nccl_reduce_type_, stream, group_name_),
                               "ncclAllReduce - SyncBatchNormGrad - CUDA failed");
  }

  size_t input_size_;
  size_t output_size_;
  size_t workspace_size_;
  float epsilon_;
  size_t param_size_S_;
  size_t param_size_G_;
  size_t param_count_;
  size_t N_;
  size_t C_;
  size_t H_;
  size_t W_;
  size_t device_count_;
  ncclRedOp_t nccl_reduce_type_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
  // NCCL
  string group_name_;
  int root_;
  const void *collective_handle_;
  cudaStream_t comm_stream_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_SYNC_BATCH_NORM_GRAD_GPU_KERNEL_H_

@ -108,9 +108,13 @@ class _BatchNorm(Cell):
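# The branch below selects the communication group used by SyncBatchNorm: on Ascend a dedicated
# "sync_bn_group0" is created over all devices, while on GPU the predefined "nccl_world_group"
# is reused when synchronizing across every device.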
            self.is_global = True
            self.group_device_num = self.rank_size
            self.device_list = [i for i in range(0, self.rank_size)]
            if SYNC_BN_GROUP_NAME == "":
                SYNC_BN_GROUP_NAME = "sync_bn_group0"
                management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
            if context.get_context("device_target") == "Ascend":
                if SYNC_BN_GROUP_NAME == "":
                    SYNC_BN_GROUP_NAME = "sync_bn_group0"
                    management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
            elif context.get_context("device_target") == "GPU":
                if SYNC_BN_GROUP_NAME == "":
                    SYNC_BN_GROUP_NAME = "nccl_world_group"

        self.shape = P.Shape()
        self.reduce_mean = P.ReduceMean(keep_dims=True)

@ -0,0 +1,134 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.communication.management import init
from mindspore.ops import composite as C

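# The reference tensors below are the expected results for a single NCHW batch. Because every rank in
# these tests feeds the same input, the synchronized statistics coincide with single-device batch-norm
# statistics, so for these inputs the forward output and the input gradient match the single-device
# result regardless of how many processes mpirun launches.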
# define target and input values here
x_fwd_input = np.array([[
    [[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]],
    [[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32)
expect_output_fwd = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294],
                                [-0.1471, 0.7706, 1.6882, 2.6059],
                                [0.3118, 1.6882, 2.1471, 2.1471],
                                [0.7706, 0.3118, 2.6059, -0.1471]],
                               [[0.9119, 1.8518, 1.3819, -0.0281],
                                [-0.0281, 0.9119, 1.3819, 1.8518],
                                [2.7918, 0.4419, -0.4981, 0.9119],
                                [1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32)
grad_back = np.array([[[[1, 2, 7, 1], [4, 2, 1, 3], [1, 6, 5, 2], [2, 4, 3, 2]],
                       [[9, 4, 3, 5], [1, 3, 7, 6], [5, 7, 9, 9], [1, 4, 6, 8]]]]).astype(np.float32)
expect_output_back = np.array([[[[-0.69126546, -0.32903028, 1.9651246, -0.88445705],
                                 [0.6369296, -0.37732816, -0.93275493, -0.11168876],
                                 [-0.7878612, 1.3614, 0.8542711, -0.52222186],
                                 [-0.37732816, 0.5886317, -0.11168876, -0.28073236]],
                                [[1.6447213, -0.38968924, -1.0174079, -0.55067265],
                                 [-2.4305856, -1.1751484, 0.86250514, 0.5502673],
                                 [0.39576983, 0.5470243, 1.1715001, 1.6447213],
                                 [-1.7996241, -0.7051701, 0.7080077, 0.5437813]]]]).astype(np.float32)


class Net(nn.Cell):
    def __init__(self, c):
        super(Net, self).__init__()
        self.num_features = c
        self.eps = 1e-5
        self.momentum = 1
        self.mode = True
        self.affine = True
        self.sync_bn_op = nn.SyncBatchNorm(num_features=self.num_features,
                                           eps=self.eps,
                                           momentum=self.momentum,
                                           affine=self.affine,
                                           gamma_init='ones',
                                           beta_init='ones',
                                           moving_mean_init='ones',
                                           moving_var_init='ones',
                                           use_batch_statistics=True,
                                           process_groups=None)

    def construct(self, input_data):
        return self.sync_bn_op(input_data)


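# Grad wraps the forward network with GradOperation(get_all=True, sens_param=True): calling it with
# (input_data, sens) returns the gradients of all network inputs given the upstream sensitivity sens.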
class Grad(nn.Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = C.GradOperation(get_all=True, sens_param=True)
        self.network = network

    def construct(self, input_data, sens):
        gout = self.grad(self.network)(input_data, sens)
        return gout


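# Each test below drives the network under GRAPH_MODE or PYNATIVE_MODE on GPU; the two asserts together
# bound the elementwise error, i.e. |output - expected| < error for every element.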
def test_sync_batch_norm_forward_fp32_graph():
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    init()
    x = x_fwd_input.copy().astype(np.float32)
    expect_output = expect_output_fwd.copy().astype(np.float32)
    overall_shape = x.shape
    error = np.ones(shape=overall_shape) * 1.0e-4
    net = Net(2)
    net.set_train()
    output = net(Tensor(x))
    diff = output.asnumpy() - expect_output
    assert np.all(diff < error)
    assert np.all(-diff < error)


def test_sync_batch_norm_forward_fp16_pynative():
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    init()
    x = x_fwd_input.copy().astype(np.float16)
    expect_output = expect_output_fwd.copy().astype(np.float16)
    overall_shape = x.shape
    error = np.ones(shape=overall_shape) * 1.0e-3
    net = Net(2)
    net.set_train()
    output = net(Tensor(x))
    diff = output.asnumpy() - expect_output
    assert np.all(diff < error)
    assert np.all(-diff < error)


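# The backward tests check output[0] only, i.e. the gradient with respect to the input tensor x,
# obtained by back-propagating grad_back through the SyncBatchNorm cell.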
def test_sync_batch_norm_backwards_fp32_graph():
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    init()
    x = x_fwd_input.copy().astype(np.float32)
    expect_output = expect_output_back.copy().astype(np.float32)
    grad = grad_back.copy().astype(np.float32)
    overall_shape = x.shape
    error = np.ones(shape=overall_shape) * 1.0e-5
    fwd_net = Net(2)
    fwd_net.set_train()
    bn_grad = Grad(fwd_net)
    output = bn_grad(Tensor(x), Tensor(grad))
    diff = output[0].asnumpy() - expect_output
    assert np.all(diff < error)
    assert np.all(-diff < error)


def test_sync_batch_norm_backwards_fp16_pynative():
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    init()
    x = x_fwd_input.copy().astype(np.float16)
    expect_output = expect_output_back.copy().astype(np.float16)
    grad = grad_back.copy().astype(np.float16)
    overall_shape = x.shape
    error = np.ones(shape=overall_shape) * 1.0e-3
    fwd_net = Net(2)
    fwd_net.set_train()
    bn_grad = Grad(fwd_net)
    output = bn_grad(Tensor(x), Tensor(grad))
    diff = output[0].asnumpy() - expect_output
    assert np.all(diff < error)
    assert np.all(-diff < error)

@ -0,0 +1,50 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================


import os
import pytest

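# Each wrapper below shells out to mpirun so that the forward tests run on 4 GPUs and the backward
# tests on a single GPU; a zero return code means every launched pytest process passed.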
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_single
def test_nccl_sync_batch_norm_1():
    cmd_str = "mpirun -n 4 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_forward_fp32_graph"
    return_code = os.system(cmd_str)
    assert return_code == 0


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_single
def test_nccl_sync_batch_norm_2():
    cmd_str = "mpirun -n 4 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_forward_fp16_pynative"
    return_code = os.system(cmd_str)
    assert return_code == 0


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_single
def test_nccl_sync_batch_norm_3():
    cmd_str = "mpirun -n 1 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_backwards_fp32_graph"
    return_code = os.system(cmd_str)
    assert return_code == 0


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_single
def test_nccl_sync_batch_norm_4():
    cmd_str = "mpirun -n 1 pytest -s test_nccl_sync_batch_norm_op.py::test_sync_batch_norm_backwards_fp16_pynative"
    return_code = os.system(cmd_str)
    assert return_code == 0