!10167 BatchNormGrad infer gpu kernel
From: @jonwe
Reviewed-by: @liangchenghui
Signed-off-by:
commit 4477b97465

@@ -0,0 +1,120 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <stdint.h>
#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#include <thrust/reduce.h>
#include <thrust/system/cuda/execution_policy.h>
#include "batchnorm_grad_impl.cuh"
#include "include/cuda_runtime.h"

const int kWarpSize = 32;
const int kBlockSize = 1024;
const int kNumWarps = 32;

template <typename T>
__global__ void BatchNormGradKernel(T *x_input, T *dy, float *scale, float *save_mean, float *save_variance, T *dx,
                                    float *bn_scale, float *bn_bias, double epsilon, int N, int C, int H, int W) {
  __shared__ T shared_dy[kNumWarps];
  __shared__ T shared_p[kNumWarps];
  int warpId = threadIdx.x / kWarpSize;
  int laneId = threadIdx.x % kWarpSize;

  int plane = blockIdx.x;
  int plane_size = N * H * W;

  T invstd = static_cast<T>(1) / static_cast<T>(sqrt(save_variance[plane] + epsilon));
  T scale_val = scale != nullptr ? static_cast<T>(scale[plane]) : static_cast<T>(1);
  T grad_scale = invstd * scale_val;

  T mean = static_cast<T>(save_mean[plane]);
  T dy_sum = static_cast<T>(0);
  T dot_p = static_cast<T>(0);

  if (threadIdx.x < kNumWarps) {
    shared_dy[threadIdx.x] = static_cast<T>(0);
    shared_p[threadIdx.x] = static_cast<T>(0);
  }
  __syncthreads();

  // Compute three values across (Batch, Height, Width) in one pass:
  // 1. dx
  // 2. Sum(dy)
  // 3. DotProduct(x - mean, dy)
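  //
  // Derivation note (editorial, not in the original file): in inference mode
  // the saved mean/variance are constants w.r.t. the input, so differentiating
  // y = scale * (x - mean) / sqrt(var + eps) + bias per channel gives
  //   dx     = dy * scale / sqrt(var + eps)
  //   dscale = Sum(dy * (x - mean)) / sqrt(var + eps)
  //   dbias  = Sum(dy)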
  for (int x = threadIdx.x; x < plane_size; x += blockDim.x) {
    int index = (x / (H * W) * C * H * W) + (plane * H * W) + (x % (H * W));
    dx[index] = static_cast<T>(dy[index] * grad_scale);
    dy_sum += dy[index];
    dot_p += (x_input[index] - mean) * dy[index];
  }
  __syncthreads();

  // Warp reduction
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    T other_dy = __shfl_down_sync(0xffffffff, dy_sum, offset);
    T other_p = __shfl_down_sync(0xffffffff, dot_p, offset);
    dy_sum += other_dy;
    dot_p += other_p;
  }
  __syncwarp();

  // Move warp-reduction result to shared memory
  if (laneId == 0) {
    shared_dy[warpId] = dy_sum;
    shared_p[warpId] = dot_p;
  }
  __syncthreads();

  // Shared memory reduction
  // There are exactly 32 items in shared memory, so they can be reduced within one warp.
  if (warpId == 0) {
    dy_sum = shared_dy[laneId];
    dot_p = shared_p[laneId];
    __syncwarp();
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
      T other_dy = __shfl_down_sync(0xffffffff, dy_sum, offset);
      T other_p = __shfl_down_sync(0xffffffff, dot_p, offset);
      dy_sum += other_dy;
      dot_p += other_p;
    }
    __syncwarp();
  }

  // Compute bn_scale & bn_bias
  if (threadIdx.x == 0) {
    bn_scale[plane] = static_cast<T>(dot_p * invstd);
  }

  if (threadIdx.x == 0) {
    bn_bias[plane] = static_cast<T>(dy_sum);
  }
}

template <typename T>
void CalBatchNormGrad(T *x, T *dy, float *scale, float *save_mean, float *save_variance, T *dx, float *bn_scale,
                      float *bn_bias, double epsilon, int N, int C, int H, int W, cudaStream_t cuda_stream) {
  BatchNormGradKernel<<<C, kBlockSize, 0, cuda_stream>>>(x, dy, scale, save_mean, save_variance, dx, bn_scale, bn_bias,
                                                         epsilon, N, C, H, W);
}

template void CalBatchNormGrad<float>(float *x, float *dy, float *scale, float *save_mean, float *save_variance,
                                      float *dx, float *bn_scale, float *bn_bias, double epsilon, int N, int C, int H,
                                      int W, cudaStream_t cuda_stream);

template void CalBatchNormGrad<half>(half *x, half *dy, float *scale, float *save_mean, float *save_variance, half *dx,
                                     float *bn_scale, float *bn_bias, double epsilon, int N, int C, int H, int W,
                                     cudaStream_t cuda_stream);
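
The reduction above follows the standard two-level CUDA pattern: each warp first reduces its 32 lanes with __shfl_down_sync, lane 0 of every warp publishes its partial sum to shared memory, and warp 0 then reduces those 32 partials. A minimal standalone sketch of the same pattern for a single float, assuming a full 1024-thread block (illustrative only, not part of this commit):

__device__ float BlockReduceSum(float val) {
  __shared__ float shared[32];  // one partial sum per warp
  int lane = threadIdx.x % 32;
  int warp = threadIdx.x / 32;
  // Level 1: reduce within each warp via register shuffles.
  for (int offset = 16; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  if (lane == 0) {
    shared[warp] = val;  // warp leaders publish their partial sums
  }
  __syncthreads();
  // Level 2: warp 0 reduces the 32 partial sums.
  if (warp == 0) {
    val = shared[lane];
    for (int offset = 16; offset > 0; offset /= 2) {
      val += __shfl_down_sync(0xffffffff, val, offset);
    }
  }
  return val;  // the block-wide sum is valid only in threadIdx.x == 0
}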

@@ -0,0 +1,24 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMGRAD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMGRAD_H_

#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalBatchNormGrad(T *x, T *dy, float *scale, float *save_mean, float *save_variance, T *dx, float *bn_scale,
                      float *bn_bias, double epsilon, int N, int C, int H, int W, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMGRAD_H_
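
As a cross-check on what the kernel computes, a minimal host-side reference of the inference-mode gradient written as plain loops over NCHW data; the function name and signature are illustrative assumptions, not part of this commit:

#include <cmath>
#include <vector>

// Reference semantics, per channel c over (N, H, W):
//   dx = dy * scale / sqrt(var + eps); dscale = Sum(dy * (x - mean)) / sqrt(var + eps); dbias = Sum(dy)
void BatchNormGradInferRef(const std::vector<float> &x, const std::vector<float> &dy,
                           const std::vector<float> &scale, const std::vector<float> &mean,
                           const std::vector<float> &variance, double epsilon, int N, int C, int H, int W,
                           std::vector<float> *dx, std::vector<float> *dscale, std::vector<float> *dbias) {
  for (int c = 0; c < C; ++c) {
    float invstd = 1.0f / std::sqrt(variance[c] + static_cast<float>(epsilon));
    float dy_sum = 0.0f;
    float dot_p = 0.0f;
    for (int n = 0; n < N; ++n) {
      for (int s = 0; s < H * W; ++s) {
        int idx = (n * C + c) * H * W + s;
        (*dx)[idx] = dy[idx] * scale[c] * invstd;  // gradient w.r.t. the input
        dy_sum += dy[idx];                         // accumulates dbias
        dot_p += (x[idx] - mean[c]) * dy[idx];     // accumulates the dscale numerator
      }
    }
    (*dscale)[c] = dot_p * invstd;
    (*dbias)[c] = dy_sum;
  }
}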

@@ -21,6 +21,7 @@
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cuh"

 namespace mindspore {
 namespace kernel {
@@ -66,16 +67,21 @@ class BatchNormGradGpuKernel : public GpuKernel {
     // For CI only, reserved vars can not be unused.
     MS_LOG(DEBUG) << reinterpret_cast<size_t>(reserve_1) << reinterpret_cast<size_t>(reserve_2);  // NOLINT

-    const float alpha_data_diff = 1;
-    const float beta_data_diff = 0;
-    const float alpha_param_diff = 1;
-    const float beta_param_diff = 0;
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
-                                      &beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_, scale,
-                                      bn_scale, bn_bias, epsilon_, save_mean, save_variance),
-      "Kernel Launch Failed.");
+    if (is_training_) {
+      const float alpha_data_diff = 1;
+      const float beta_data_diff = 0;
+      const float alpha_param_diff = 1;
+      const float beta_param_diff = 0;
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        kernel_node_,
+        cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
+                                        &beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_,
+                                        scale, bn_scale, bn_bias, epsilon_, save_mean, save_variance),
+        "Kernel Launch Failed.");
+    } else {
+      CalBatchNormGrad(x, dy, scale, save_mean, save_variance, dx, bn_scale, bn_bias, epsilon_, batch_, channel_,
+                       height_, width_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {

@@ -104,6 +110,7 @@ class BatchNormGradGpuKernel : public GpuKernel {
     width_ = SizeToInt(shape[3]);

     mode_ = CUDNN_BATCHNORM_SPATIAL;
+    is_training_ = GetAttr<bool>(kernel_node, "is_training");
     epsilon_ = GetAttr<float>(kernel_node, "epsilon");

     CHECK_CUDNN_RET_WITH_EXCEPT(

@@ -175,6 +182,7 @@ class BatchNormGradGpuKernel : public GpuKernel {
   int width_;

   cudnnBatchNormMode_t mode_;
+  bool is_training_;
   double epsilon_;
   bool is_null_input_;
   cudnnTensorDescriptor_t x_desc_;

@@ -178,3 +178,26 @@ def test_train_stats_false_forward():
     diff = output.asnumpy() - expect_output
     assert np.all(diff < error)
     assert np.all(-diff < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_infer_backward():
+    expect_output = np.array([[[[-0.3224156, -0.3840524], [1.1337637, -1.0998858]],
+                               [[-0.1724273, -0.877854], [0.0422135, 0.5828123]],
+                               [[-1.1006137, 1.1447179], [0.9015862, 0.5024918]]]]).astype(np.float32)
+    np.random.seed(1)
+    x_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
+    input_grad_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
+    ms_input = Tensor(x_np)
+    weight = Tensor(np.ones(3).astype(np.float32))
+    bias = Tensor(np.zeros(3).astype(np.float32))
+    moving_mean = Tensor(np.zeros(3).astype(np.float32))
+    moving_var_init = Tensor(np.ones(3).astype(np.float32))
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    ms_net = Batchnorm_Net(3, weight, bias, moving_mean, moving_var_init)
+    ms_net.set_train(False)
+    ms_grad = Grad(ms_net)
+    ms_out_grad_np = ms_grad(ms_input, Tensor(input_grad_np))
+    assert np.allclose(ms_out_grad_np[0].asnumpy(), expect_output)