!19791 Fix PReLU weight-gradient accuracy error for fp16 on GPU

Merge pull request !19791 from zhangbuxue/fix_prelu_weight_grad_accuracy_error_fp16_on_GPU
This commit is contained in:
i-robot 2021-07-10 02:56:57 +00:00 committed by Gitee
commit de635448af
6 changed files with 85 additions and 103 deletions

View File

@ -20,39 +20,48 @@
// Computes the PReLU input gradient dx and per-thread partial sums of the weight gradient.
//
// Launch layout: 1-D grid of 1-D blocks. Any grid size is valid because each thread starts at its
// global thread id and advances by gridDim.x * blockDim.x until it runs past `size`.
//
// size:             number of elements in dy / x / dx.
// weight_size:      number of PReLU weights; 1 means one weight shared by every element.
// per_channel_size: divisor used to map an element position to its weight when weight_size != 1
//                   (channel = (pos / per_channel_size) % weight_size).
// dy, x, w:         upstream gradient, forward input and weights.
// dx:               output, gradient with respect to x.
// dw_array:         float scratch laid out as [weight_size][gridDim.x * blockDim.x]. It must be zero-filled
//                   before this kernel runs. Thread t only ever touches column t, so the non-atomic "+="
//                   below is race-free; summing row c afterwards yields dw[c].
template <typename T>
__global__ void CalPReLUGradKernel(size_t size, size_t weight_size, size_t per_channel_size,
                                   const T *dy, const T *x, const T *w, T *dx, float *dw_array) {
  const size_t thread_num = static_cast<size_t>(blockDim.x) * gridDim.x;
  const size_t thread_id = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const T zero = static_cast<T>(0);
  for (size_t pos = thread_id; pos < size; pos += thread_num) {
    const size_t channel_id = weight_size == 1 ? 0 : (pos / per_channel_size) % weight_size;
    // x == 0 takes the w * dy branch here but contributes nothing to dw below (strict "<").
    dx[pos] = x[pos] <= zero ? w[channel_id] * dy[pos] : dy[pos];
    if (x[pos] < zero) {
      // Index the scratch by this thread's slot, not by pos: once pos >= thread_num (grid smaller than
      // size), "channel_id * thread_num + pos" lands in another channel's row or past the end of dw_array.
      const size_t index = channel_id * thread_num + thread_id;
      // Widen both operands before multiplying so a half product is not rounded before accumulation.
      dw_array[index] += static_cast<float>(x[pos]) * static_cast<float>(dy[pos]);
    }
  }
}
// Zero-fills the float scratch buffer that holds weight-gradient partial sums.
//
// Launch layout: 1-D grid of 1-D blocks; each thread starts at its global thread id and advances by
// gridDim.x * blockDim.x, so any grid size covers the whole buffer.
//
// dw_array_size: number of float elements in dw_array.
// dw_array:      buffer to clear.
__global__ void InitDwArrayData(size_t dw_array_size, float *dw_array) {
  // Do the index arithmetic in size_t so blockIdx.x * blockDim.x cannot wrap in 32 bits.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < dw_array_size; i += stride) {
    dw_array[i] = 0.0f;
  }
}
// Reduces the per-thread partial sums in dw_array into the weight gradient dw.
//
// Launch layout: 1-D grid of 1-D blocks; each thread handles weights i, i + stride, ... where
// stride = gridDim.x * blockDim.x.
//
// weight_size: number of weights, i.e. rows of dw_array and elements of dw.
// thread_num:  number of partial sums stored per weight (row length of dw_array).
// dw_array:    float partial sums laid out as [weight_size][thread_num].
// dw:          output; dw[i] is the sum of dw_array[i * thread_num + j] over j, converted to T.
template <typename T>
__global__ void ComputeDwData(size_t weight_size, size_t thread_num, const float *dw_array, T *dw) {
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < weight_size; i += stride) {
    // Accumulate in float and round once at the end; accumulating in T would round every partial
    // sum to half precision when T is half and throw away the accuracy of the float scratch.
    float value = 0.0f;
    const float *row = dw_array + i * thread_num;
    for (size_t j = 0; j < thread_num; j++) {
      value += row[j];
    }
    dw[i] = static_cast<T>(value);
  }
}
// Host-side launcher for the PReLU backward pass; all kernels are enqueued on cuda_stream.
// NOTE(review): this span appears to interleave two revisions of the launcher. One has no dw_array
// parameter, launches InitDwData and passes dw to CalPReLUGradKernel; the other takes a float *dw_array
// scratch buffer and runs the three-kernel sequence described below.
//
// Sequence with dw_array:
//   1. InitDwArrayData zero-fills dw_array (dw_array_size elements).
//   2. CalPReLUGradKernel is launched with the dx output and the dw_array scratch.
//   3. ComputeDwData sums thread_num entries of dw_array per weight into dw.
template <typename T>
void CalPReLUGrad(size_t size, size_t weight_size, size_t per_channel_size,
const T *dy, const T *x, const T *w, T *dx, T *dw, cudaStream_t cuda_stream) {
InitDwData<<<GET_BLOCKS(weight_size), GET_THREADS, 0, cuda_stream>>>(weight_size, dw);
const T *dy, const T *x, const T *w, T *dx, T *dw, float *dw_array, cudaStream_t cuda_stream) {
// Total number of threads CalPReLUGradKernel is launched with; also the per-weight row length of dw_array.
size_t thread_num = static_cast<size_t>(GET_BLOCKS(size) * GET_THREADS);
// dw_array must therefore hold weight_size * thread_num floats.
size_t dw_array_size = weight_size * thread_num;
InitDwArrayData<<<GET_BLOCKS(dw_array_size), GET_THREADS, 0, cuda_stream>>>(dw_array_size, dw_array);
CalPReLUGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, weight_size, per_channel_size,
dy, x, w, dx, dw);
dy, x, w, dx, dw_array);
ComputeDwData<<<GET_BLOCKS(weight_size), GET_THREADS, 0, cuda_stream>>>(weight_size, thread_num, dw_array, dw);
return;
}
// Explicit instantiations of CalPReLUGrad for float and half element types.
// NOTE(review): the first pair matches the signature without the float *dw_array scratch argument and
// the second pair matches the signature that takes it; only one pair can correspond to the definition
// that is actually compiled — confirm which revision applies.
template void CalPReLUGrad(size_t, size_t, size_t, const float *, const float *, const float *, float *, float *,
cudaStream_t);
template void CalPReLUGrad(size_t, size_t, size_t, const half *, const half *, const half *, half *, half *,
cudaStream_t);
template void CalPReLUGrad(size_t, size_t, size_t, const float *, const float *, const float *,
float *, float *, float *, cudaStream_t);
template void CalPReLUGrad(size_t, size_t, size_t, const half *, const half *, const half *,
half *, half *, float *, cudaStream_t);

View File

@ -21,5 +21,5 @@
// Declaration of the PReLU backward launcher.
// dy, x, w are inputs; dx and dw are outputs; cuda_stream is the stream the work is enqueued on.
// NOTE(review): two parameter lists appear here, one without and one with the float *dw_array
// scratch argument — presumably two revisions of the same declaration; confirm which one is current.
template <typename T>
void CalPReLUGrad(size_t input_size, size_t weight_size, size_t per_channel_size,
const T *dy, const T *x, const T *w, T *dx, T *dw, cudaStream_t cuda_stream);
const T *dy, const T *x, const T *w, T *dx, T *dw, float *dw_array, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_PRELU_GRAD_H_

View File

@ -18,22 +18,18 @@
// Element-wise PReLU forward kernel: output = input < 0 ? weight[channel] * input : input.
// The channel is 0 when weight_size == 1, otherwise (pos / per_channel_size) % weight_size.
// Each thread starts at blockIdx.x * blockDim.x + threadIdx.x and advances by blockDim.x * gridDim.x
// until pos >= size.
// NOTE(review): this span appears to interleave two revisions of the same kernel (parameters named
// input_addr/weight_addr/output_addr with a local `index`, and parameters named input/weight/output
// with a local `channel_id`); both evaluate the same expression.
template <typename T>
__global__ void CalPReLUKernel(size_t size, size_t weight_size, size_t per_channel_size,
const T *input_addr, const T *weight_addr, T *output_addr) {
const T *input, const T *weight, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
size_t index = 0;
if (weight_size != 1) {
index = (pos / per_channel_size) % weight_size;
}
T threshold = static_cast<T>(0);
output_addr[pos] = input_addr[pos] < threshold ? weight_addr[index] * input_addr[pos] : input_addr[pos];
size_t channel_id = weight_size == 1 ? 0 : (pos / per_channel_size) % weight_size;
output[pos] = input[pos] < static_cast<T>(0) ? weight[channel_id] * input[pos] :input[pos];
}
}
// Host-side launcher for the PReLU forward pass: enqueues CalPReLUKernel on cuda_stream with
// GET_BLOCKS(size) blocks of GET_THREADS threads and forwards all arguments unchanged.
// NOTE(review): two parameter lists and two argument lists appear here (the *_addr names and the
// shorter input/weight/output names) — presumably two revisions of the same function.
template <typename T>
void CalPReLU(size_t size, size_t weight_size, size_t per_channel_size,
const T *input_addr, const T *weight_addr, T *output_addr, cudaStream_t cuda_stream) {
const T *input, const T *weight, T *output, cudaStream_t cuda_stream) {
CalPReLUKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, weight_size, per_channel_size,
input_addr, weight_addr, output_addr);
input, weight, output);
return;
}

View File

@ -21,5 +21,5 @@
// Declaration of the PReLU forward launcher; the work is enqueued on cuda_stream.
// NOTE(review): two parameter lists appear here that differ only in parameter names
// (input_addr/weight_addr/output_addr versus input/weight/output).
template <typename T>
void CalPReLU(size_t input_size, size_t weight_size, size_t per_channel_size,
const T *input_addr, const T *weight_addr, T *output_addr, cudaStream_t cuda_stream);
const T *input, const T *weight, T *output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_PRELU_H_

View File

@ -36,15 +36,16 @@ class PReLUGradGpuKernel : public GpuKernel {
// Returns the stored list of output buffer sizes (output_size_list_).
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
// Returns the stored list of workspace buffer sizes (workspace_size_list_).
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
// Fetches the device pointers for this op and calls CalPReLUGrad on the given stream.
// Inputs are read as 0 = dy, 1 = x, 2 = w; outputs as 0 = dx, 1 = dw; workspace 0 is read as the
// float scratch buffer dw_array. Always returns true.
// NOTE(review): two signature lines and two CalPReLUGrad call lines appear here (without and with
// the workspace / dw_array argument) — presumably two revisions of the same method.
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto *dy = GetDeviceAddress<T>(inputs, 0);
auto *x = GetDeviceAddress<T>(inputs, 1);
auto *w = GetDeviceAddress<T>(inputs, 2);
auto *dx = GetDeviceAddress<T>(outputs, 0);
auto *dw = GetDeviceAddress<T>(outputs, 1);
// The scratch buffer is always float, independent of the element type T.
auto *dw_array = GetDeviceAddress<float>(workspace, 0);
CalPReLUGrad(input_length_, weight_length_, per_channel_length_, dy, x, w, dx, dw,
CalPReLUGrad(input_length_, weight_length_, per_channel_length_, dy, x, w, dx, dw, dw_array,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@ -84,6 +85,7 @@ class PReLUGradGpuKernel : public GpuKernel {
<< channel_num << ", but got weight shape " << weight_shape;
}
weight_length_ = weight_shape[0];
workspace_size_ = weight_length_ * IntToSize(GET_BLOCKS(input_length_) * GET_THREADS) * sizeof(float);
InitSizeLists();
return true;
}
@ -105,12 +107,14 @@ class PReLUGradGpuKernel : public GpuKernel {
input_size_list_.push_back(weight_length_ * data_size);
output_size_list_.push_back(input_length_ * data_size);
output_size_list_.push_back(weight_length_ * data_size);
workspace_size_list_.push_back(workspace_size_);
}
private:
size_t input_length_{0};
size_t weight_length_{0};
size_t per_channel_length_{0};
size_t workspace_size_{0};
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

View File

@ -51,8 +51,7 @@ def judge_result_correct(result, expect):
assert np.allclose(result, expect, rtol=1.e-2)
def test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode):
context.set_context(mode=mode)
def test_prelu(x, weight, expect_forward, expect_dx, expect_dw):
prelu_forward = PReLUOpNet()
prelu_backward = PReLUOpGradNet(prelu_forward)
forward_output = prelu_forward(x, weight)
@ -64,14 +63,14 @@ def test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode):
judge_result_correct(backward_output[1], expect_dw)
context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
dtypes = [mstype.float16, mstype.float32]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_single_weight():
context.set_context(device_target="GPU")
dtypes = [mstype.float16, mstype.float32]
modes = [context.GRAPH_MODE, context.GRAPH_MODE]
x = np.arange(-10, 26).reshape((2, 3, 2, 3)) * 0.7
weight = np.array([0.6])
expect_forward = np.where(x >= 0, x, weight * x)
@ -79,23 +78,18 @@ def test_prelu_single_weight():
expect_dw = np.sum(np.where(x >= 0, 0, x)).reshape((1,))
for dtype in dtypes:
for mode in modes:
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode)
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_multiple_weight():
context.set_context(device_target="GPU")
dtypes = [mstype.float16, mstype.float32]
modes = [context.GRAPH_MODE, context.GRAPH_MODE]
x = np.arange(-10, 26).reshape((2, 3, 2, 3)) * 0.6
weight = np.array([0.2, 0.3, 0.4])
expect_forward = np.array([[[[-1.20, -1.08, -0.96],
@ -125,23 +119,18 @@ def test_prelu_multiple_weight():
expect_dw = np.array([-27.0, -6.0, 0.0])
for dtype in dtypes:
for mode in modes:
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode)
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_single_weight_0_D():
context.set_context(device_target="GPU")
dtypes = [mstype.float16, mstype.float32]
modes = [context.GRAPH_MODE, context.GRAPH_MODE]
x = np.array(-0.8)
weight = np.array([0.6])
expect_forward = np.array(-0.48)
@ -149,23 +138,18 @@ def test_prelu_single_weight_0_D():
expect_dw = np.array([-0.8])
for dtype in dtypes:
for mode in modes:
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode)
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw)
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_single_weight_1_D():
context.set_context(device_target="GPU")
dtypes = [mstype.float16, mstype.float32]
modes = [context.GRAPH_MODE, context.GRAPH_MODE]
x = np.arange(-10, 26).reshape((36,)) * 0.7
weight = np.array([0.6])
expect_forward = np.where(x >= 0, x, weight * x)
@ -173,23 +157,18 @@ def test_prelu_single_weight_1_D():
expect_dw = np.sum(np.where(x >= 0, 0, x)).reshape((1,))
for dtype in dtypes:
for mode in modes:
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode)
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw)
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_single_weight_2_D():
context.set_context(device_target="GPU")
dtypes = [mstype.float16, mstype.float32]
modes = [context.GRAPH_MODE, context.GRAPH_MODE]
x = np.arange(-10, 26).reshape((4, 9)) * 0.7
weight = np.array([0.6])
expect_forward = np.where(x >= 0, x, weight * x)
@ -197,23 +176,18 @@ def test_prelu_single_weight_2_D():
expect_dw = np.sum(np.where(x >= 0, 0, x)).reshape((1,))
for dtype in dtypes:
for mode in modes:
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode)
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_multiple_weight_2_D():
context.set_context(device_target="GPU")
dtypes = [mstype.float16, mstype.float32]
modes = [context.GRAPH_MODE, context.GRAPH_MODE]
x = np.arange(-6, 6).reshape((3, 4)) * 0.6
weight = np.array([0.2, 0.4, 0.7, 0.9])
expect_forward = np.array([[-0.72, -1.20, -1.68, -1.62],
@ -225,10 +199,9 @@ def test_prelu_multiple_weight_2_D():
expect_dw = np.array([-4.8, -3.6, -2.4, -1.8])
for dtype in dtypes:
for mode in modes:
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode)
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw)