!10167 BatchNormGrad infer gpu kernel

From: @jonwe
Reviewed-by: @liangchenghui
Signed-off-by:
mindspore-ci-bot 2020-12-21 10:41:52 +08:00 committed by Gitee
commit 4477b97465
4 changed files with 185 additions and 10 deletions


@@ -0,0 +1,120 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdint.h>
#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#include <thrust/reduce.h>
#include <thrust/system/cuda/execution_policy.h>
#include "batchnorm_grad_impl.cuh"
#include "include/cuda_runtime.h"
const int kWarpSize = 32;
const int kBlockSize = 1024;
const int kNumWarps = 32;
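// Inference-mode BatchNorm backward: the running statistics are treated as
// constants, so no gradient flows through mean/variance and, per channel c:
//   dx     = dy * gamma_c / sqrt(var_c + epsilon)
//   dgamma = Sum_{N,H,W}(dy * (x - mean_c)) / sqrt(var_c + epsilon)
//   dbeta  = Sum_{N,H,W}(dy)
// The kernel below runs one thread block per channel (gridDim.x == C).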
template <typename T>
__global__ void BatchNormGradKernel(T *x_input, T *dy, float *scale, float *save_mean, float *save_variance, T *dx,
float *bn_scale, float *bn_bias, double epsilon, int N, int C, int H, int W) {
__shared__ T shared_dy[kNumWarps];
__shared__ T shared_p[kNumWarps];
int warpId = threadIdx.x / kWarpSize;
int laneId = threadIdx.x % kWarpSize;
int plane = blockIdx.x;
int plane_size = N * H * W;
T invstd = static_cast<T>(1) / static_cast<T>(sqrt(save_variance[plane] + epsilon));
T scale_val = scale != nullptr ? static_cast<T>(scale[plane]) : static_cast<T>(1);
T grad_scale = invstd * scale_val;
T mean = static_cast<T>(save_mean[plane]);
T dy_sum = static_cast<T>(0);
T dot_p = static_cast<T>(0);
if (threadIdx.x < kNumWarps) {
shared_dy[threadIdx.x] = static_cast<T>(0);
shared_p[threadIdx.x] = static_cast<T>(0);
}
__syncthreads();
// Compute three values across (Batch, Height, Width) in one pass:
// 1. dx
// 2. Sum(dy)
// 3. DotProduct(x - mean, dy)
for (int x = threadIdx.x; x < plane_size; x += blockDim.x) {
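    // Map the flat (N*H*W) offset x to an NCHW index for channel `plane`:
    // batch term (x / (H*W)) * C*H*W, channel term plane * H*W, spatial term x % (H*W).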
int index = (x / (H * W) * C * H * W) + (plane * H * W) + (x % (H * W));
dx[index] = static_cast<T>(dy[index] * grad_scale);
dy_sum += dy[index];
dot_p += (x_input[index] - mean) * dy[index];
}
__syncthreads();
// Warp reduction
for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
T other_dy = __shfl_down_sync(0xffffffff, dy_sum, offset);
T other_p = __shfl_down_sync(0xffffffff, dot_p, offset);
dy_sum += other_dy;
dot_p += other_p;
}
__syncwarp();
// Move warp-reduction result to shared memory
if (laneId == 0) {
shared_dy[warpId] = dy_sum;
shared_p[warpId] = dot_p;
}
__syncthreads();
// Shared memory reduction
  // There are exactly 32 partial sums in shared memory, so they can be reduced within one warp.
if (warpId == 0) {
dy_sum = shared_dy[laneId];
dot_p = shared_p[laneId];
__syncwarp();
for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
T other_dy = __shfl_down_sync(0xffffffff, dy_sum, offset);
T other_p = __shfl_down_sync(0xffffffff, dot_p, offset);
dy_sum += other_dy;
dot_p += other_p;
}
__syncwarp();
}
  // Thread 0 now holds the block-wide sums: write dgamma (bn_scale) and dbeta (bn_bias).
  if (threadIdx.x == 0) {
    bn_scale[plane] = static_cast<float>(dot_p * invstd);
    bn_bias[plane] = static_cast<float>(dy_sum);
  }
}
template <typename T>
void CalBatchNormGrad(T *x, T *dy, float *scale, float *save_mean, float *save_variance, T *dx, float *bn_scale,
float *bn_bias, double epsilon, int N, int C, int H, int W, cudaStream_t cuda_stream) {
BatchNormGradKernel<<<C, kBlockSize, 0, cuda_stream>>>(x, dy, scale, save_mean, save_variance, dx, bn_scale, bn_bias,
epsilon, N, C, H, W);
}
template void CalBatchNormGrad<float>(float *x, float *dy, float *scale, float *save_mean, float *save_variance,
float *dx, float *bn_scale, float *bn_bias, double epsilon, int N, int C, int H,
int W, cudaStream_t cuda_stream);
template void CalBatchNormGrad<half>(half *x, half *dy, float *scale, float *save_mean, float *save_variance, half *dx,
float *bn_scale, float *bn_bias, double epsilon, int N, int C, int H, int W,
cudaStream_t cuda_stream);
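
For reference, a minimal host-side sketch of how the new CalBatchNormGrad launcher could be driven on its own; the buffer setup, the RunBatchNormGradFP32 name, and the 1e-5 epsilon are illustrative assumptions, not part of this PR:

// Hypothetical standalone driver for CalBatchNormGrad (illustrative only).
#include <cuda_runtime.h>
#include "batchnorm_grad_impl.cuh"

void RunBatchNormGradFP32(int N, int C, int H, int W) {
  size_t feature_elems = static_cast<size_t>(N) * C * H * W;
  float *x, *dy, *dx, *scale, *mean, *var, *dgamma, *dbeta;
  cudaMalloc(&x, feature_elems * sizeof(float));   // forward-pass input
  cudaMalloc(&dy, feature_elems * sizeof(float));  // upstream gradient
  cudaMalloc(&dx, feature_elems * sizeof(float));  // output: gradient w.r.t. x
  cudaMalloc(&scale, C * sizeof(float));           // gamma
  cudaMalloc(&mean, C * sizeof(float));            // running mean
  cudaMalloc(&var, C * sizeof(float));             // running variance
  cudaMalloc(&dgamma, C * sizeof(float));          // output: bn_scale
  cudaMalloc(&dbeta, C * sizeof(float));           // output: bn_bias
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // ... fill x, dy, scale, mean, and var on the device ...
  CalBatchNormGrad(x, dy, scale, mean, var, dx, dgamma, dbeta,
                   1e-5, N, C, H, W, stream);
  cudaStreamSynchronize(stream);
  // ... copy dx, dgamma, dbeta back to the host and cudaFree all buffers ...
  cudaStreamDestroy(stream);
}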


@@ -0,0 +1,24 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMGRAD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMGRAD_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalBatchNormGrad(T *x, T *dy, float *scale, float *save_mean, float *save_variance, T *dx, float *bn_scale,
float *bn_bias, double epsilon, int N, int C, int H, int W, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMGRAD_H_


@@ -21,6 +21,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cuh"
namespace mindspore {
namespace kernel {
@@ -66,16 +67,21 @@ class BatchNormGradGpuKernel : public GpuKernel {
    // For CI only; reserved vars cannot be unused.
MS_LOG(DEBUG) << reinterpret_cast<size_t>(reserve_1) << reinterpret_cast<size_t>(reserve_2); // NOLINT
const float alpha_data_diff = 1;
const float beta_data_diff = 0;
const float alpha_param_diff = 1;
const float beta_param_diff = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
&beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_, scale,
bn_scale, bn_bias, epsilon_, save_mean, save_variance),
"Kernel Launch Failed.");
if (is_training_) {
const float alpha_data_diff = 1;
const float beta_data_diff = 0;
const float alpha_param_diff = 1;
const float beta_param_diff = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
&beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_,
scale, bn_scale, bn_bias, epsilon_, save_mean, save_variance),
"Kernel Launch Failed.");
} else {
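      // The inference forward pass treats the running statistics as constants, so the
      // backward pass reduces to the per-channel scaling computed by CalBatchNormGrad.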
CalBatchNormGrad(x, dy, scale, save_mean, save_variance, dx, bn_scale, bn_bias, epsilon_, batch_, channel_,
height_, width_, reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
@@ -104,6 +110,7 @@ class BatchNormGradGpuKernel : public GpuKernel {
width_ = SizeToInt(shape[3]);
mode_ = CUDNN_BATCHNORM_SPATIAL;
is_training_ = GetAttr<bool>(kernel_node, "is_training");
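    // epsilon is consumed by both the cuDNN training path and the new inference kernel.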
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
CHECK_CUDNN_RET_WITH_EXCEPT(
@@ -175,6 +182,7 @@ class BatchNormGradGpuKernel : public GpuKernel {
int width_;
cudnnBatchNormMode_t mode_;
bool is_training_;
double epsilon_;
bool is_null_input_;
cudnnTensorDescriptor_t x_desc_;


@@ -178,3 +178,26 @@ def test_train_stats_false_forward():
diff = output.asnumpy() - expect_output
assert np.all(diff < error)
assert np.all(-diff < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_infer_backward():
expect_output = np.array([[[[-0.3224156, -0.3840524], [1.1337637, -1.0998858]],
[[-0.1724273, -0.877854], [0.0422135, 0.5828123]],
[[-1.1006137, 1.1447179], [0.9015862, 0.5024918]]]]).astype(np.float32)
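    # With gamma = 1 and moving variance = 1, the inference-mode input gradient is
    # dy / sqrt(1 + eps), i.e. numerically (almost) the upstream gradient itself.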
np.random.seed(1)
x_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
input_grad_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
ms_input = Tensor(x_np)
weight = Tensor(np.ones(3).astype(np.float32))
bias = Tensor(np.zeros(3).astype(np.float32))
moving_mean = Tensor(np.zeros(3).astype(np.float32))
moving_var_init = Tensor(np.ones(3).astype(np.float32))
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
ms_net = Batchnorm_Net(3, weight, bias, moving_mean, moving_var_init)
ms_net.set_train(False)
ms_grad = Grad(ms_net)
ms_out_grad_np = ms_grad(ms_input, Tensor(input_grad_np))
assert np.allclose(ms_out_grad_np[0].asnumpy(), expect_output)