!11785 gpu RangeOp better error handling in kernel

From: @peilin-wang
Reviewed-by: @robingrosman,@tom__chen
Signed-off-by: @tom__chen
This commit is contained in:
mindspore-ci-bot 2021-01-29 21:06:52 +08:00 committed by Gitee
commit 96f007ebb4
4 changed files with 143 additions and 36 deletions

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -44,22 +44,58 @@ class DynamicRangeGpuKernel : public GpuKernel {
T *range_delta = GetDeviceAddress<T>(inputs, 2); T *range_delta = GetDeviceAddress<T>(inputs, 2);
T *output_device_address = GetDeviceAddress<T>(outputs, 0); T *output_device_address = GetDeviceAddress<T>(outputs, 0);
int64_t *output_shape_device_address = GetDeviceAddress<int64_t>(workspace, 0); int64_t *output_shape_device_address = GetDeviceAddress<int64_t>(workspace, 0);
DynamicRangeErrorCode *error_code_device_address = GetDeviceAddress<DynamicRangeErrorCode>(workspace, 1);
stream_ptr_ = stream_ptr; stream_ptr_ = stream_ptr;
CalRange(range_start, range_end, range_delta, output_device_address, output_shape_device_address, CudaValidateInputAndInferShape(range_start, range_end, range_delta, output_shape_device_address,
max_output_length_, reinterpret_cast<cudaStream_t>(stream_ptr)); error_code_device_address, max_output_length_,
reinterpret_cast<cudaStream_t>(stream_ptr));
DynamicRangeErrorCode error_code = DynamicRangeErrorCode::kOk;
CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_,
cudaMemcpyAsync(&error_code, error_code_device_address, sizeof(DynamicRangeErrorCode),
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
"Failed to copy error code to host.");
CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
// use workspace[0] for actual output shape, we know it must be 1d // use workspace[0] for actual output shape, we know it must be 1d
CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_, CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_,
cudaMemcpyAsync(&output_shape_, output_shape_device_address, sizeof(int64_t), cudaMemcpyAsync(&output_shape_, output_shape_device_address, sizeof(int64_t),
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)), cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
"Failed to copy gpu memory."); "Failed to copy output_shape to host.");
CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed"); CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
LogExceptionIfNotOk(error_code);
CalRange(range_start, range_end, range_delta, output_device_address, output_shape_device_address,
error_code_device_address, max_output_length_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true; return true;
} }
void LogExceptionIfNotOk(DynamicRangeErrorCode error_code) {
switch (error_code) {
case DynamicRangeErrorCode::kOk:
return;
case DynamicRangeErrorCode::kDeltaIsZero:
MS_LOG(EXCEPTION) << "gpu RangeOp input error: delta cannot be equal to zero";
break;
case DynamicRangeErrorCode::kInvalidPositiveDelta:
MS_LOG(EXCEPTION) << "gpu RangeOp input error: delta cannot be positive when limit < start";
break;
case DynamicRangeErrorCode::kInvalidNegativeDelta:
MS_LOG(EXCEPTION) << "gpu RangeOp input error: delta cannot be negative when limit > start";
break;
case DynamicRangeErrorCode::kMaxSizeExceeded:
MS_LOG(EXCEPTION) << "gpu RangeOp memory error: the number of elements in the output exceeds maxlen";
break;
default:
MS_LOG(EXCEPTION) << "gpu RangeOp unknown error";
}
}
void PostExecute() override { void PostExecute() override {
// required synchronize for PostExecute // required synchronize for PostExecute
CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)), CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
@ -103,6 +139,7 @@ class DynamicRangeGpuKernel : public GpuKernel {
// this op outputs a 1d tensor, size of one int64_t is enough space to hold the shape. // this op outputs a 1d tensor, size of one int64_t is enough space to hold the shape.
workspace_size_list_.push_back(sizeof(int64_t)); workspace_size_list_.push_back(sizeof(int64_t));
workspace_size_list_.push_back(sizeof(DynamicRangeErrorCode));
return; return;
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -20,57 +20,90 @@
#include "runtime/device/gpu/cuda_common.h" #include "runtime/device/gpu/cuda_common.h"
template <typename T> template <typename T>
__device__ void CheckInputs(const T &start, const T &end, const T &delta) { __global__ void ValidateInputAndInferShape(const T *range_start, const T *range_end, const T *range_delta,
int64_t *output_shape, DynamicRangeErrorCode *error_code,
const int64_t max_output_size) {
T start = range_start[0];
T end = range_end[0];
T delta = range_delta[0];
*error_code = DynamicRangeErrorCode::kOk;
if (delta == 0) { if (delta == 0) {
asm("trap;"); *error_code = DynamicRangeErrorCode::kDeltaIsZero;
return;
} }
if (start < end && delta < 0) { if (start < end && delta < 0) {
asm("trap;"); *error_code = DynamicRangeErrorCode::kInvalidNegativeDelta;
return;
} }
if (start > end && delta > 0) { if (start > end && delta > 0) {
asm("trap;"); *error_code = DynamicRangeErrorCode::kInvalidPositiveDelta;
return;
}
if (*error_code == DynamicRangeErrorCode::kOk) {
int64_t real_output_shape = static_cast<int64_t>(ceil(static_cast<double>(end - start) / delta));
if (real_output_shape > max_output_size) {
*error_code = DynamicRangeErrorCode::kMaxSizeExceeded;
}
*output_shape = real_output_shape;
} }
} }
template <typename T> template <typename T>
__global__ void Range(const T *range_start, const T *range_end, const T *range_delta, T *output, __global__ void Range(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
int64_t *output_shape, const int64_t max_output_size) { const int64_t max_output_size) {
T start = range_start[0]; T start = range_start[0];
T end = range_end[0];
T delta = range_delta[0]; T delta = range_delta[0];
CheckInputs(start, end, delta);
int64_t real_output_shape = static_cast<int64_t>(ceil(static_cast<double>(end - start) / delta));
if (real_output_shape > max_output_size) {
asm("trap;");
}
*output_shape = real_output_shape;
size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x;
for (; gt_id < real_output_shape; gt_id += blockDim.x * gridDim.x) { for (; gt_id < *output_shape; gt_id += blockDim.x * gridDim.x) {
output[gt_id] = gt_id * delta + start; output[gt_id] = gt_id * delta + start;
} }
} }
template <typename T>
void CudaValidateInputAndInferShape(const T *range_start, const T *range_end, const T *range_delta,
int64_t *output_shape, DynamicRangeErrorCode *error_code,
const int64_t max_output_size, cudaStream_t cuda_stream) {
ValidateInputAndInferShape<<<1, 1, 0, cuda_stream>>>(range_start, range_end, range_delta, output_shape, error_code,
max_output_size);
}
template <typename T> template <typename T>
void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape, void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
const int64_t max_output_size, cudaStream_t cuda_stream) { DynamicRangeErrorCode *error_code, const int64_t max_output_size, cudaStream_t cuda_stream) {
Range<<<GET_BLOCKS(max_output_size), GET_THREADS, 0, cuda_stream>>>(range_start, range_end, range_delta, Range<<<GET_BLOCKS(max_output_size), GET_THREADS, 0, cuda_stream>>>(range_start, range_end, range_delta,
output, output_shape, max_output_size); output, output_shape, max_output_size);
} }
template void CudaValidateInputAndInferShape<int>(const int *range_start, const int *range_end, const int *range_delta,
int64_t *output_shape, DynamicRangeErrorCode *error_code,
const int64_t max_output_size, cudaStream_t cuda_stream);
template void CudaValidateInputAndInferShape<int64_t>(const int64_t *range_start, const int64_t *range_end,
const int64_t *range_delta, int64_t *output_shape,
DynamicRangeErrorCode *error_code, const int64_t max_output_size,
cudaStream_t cuda_stream);
template void CudaValidateInputAndInferShape<float>(const float *range_start, const float *range_end,
const float *range_delta, int64_t *output_shape,
DynamicRangeErrorCode *error_code, const int64_t max_output_size,
cudaStream_t cuda_stream);
template void CudaValidateInputAndInferShape<double>(const double *range_start, const double *range_end,
const double *range_delta, int64_t *output_shape,
DynamicRangeErrorCode *error_code, const int64_t max_output_size,
cudaStream_t cuda_stream);
template void CalRange<int>(const int *range_start, const int *range_end, const int *range_delta, int *output, template void CalRange<int>(const int *range_start, const int *range_end, const int *range_delta, int *output,
int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream); int64_t *output_shape, DynamicRangeErrorCode *error_code, const int64_t max_output_size,
cudaStream_t cuda_stream);
template void CalRange<int64_t>(const int64_t *range_start, const int64_t *range_end, const int64_t *range_delta, template void CalRange<int64_t>(const int64_t *range_start, const int64_t *range_end, const int64_t *range_delta,
int64_t *output, int64_t *output_shape, const int64_t max_output_size, int64_t *output, int64_t *output_shape, DynamicRangeErrorCode *error_code,
cudaStream_t cuda_stream); const int64_t max_output_size, cudaStream_t cuda_stream);
template void CalRange<float>(const float *range_start, const float *range_end, const float *range_delta, float *output, template void CalRange<float>(const float *range_start, const float *range_end, const float *range_delta, float *output,
int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream); int64_t *output_shape, DynamicRangeErrorCode *error_code, const int64_t max_output_size,
cudaStream_t cuda_stream);
template void CalRange<double>(const double *range_start, const double *range_end, const double *range_delta, template void CalRange<double>(const double *range_start, const double *range_end, const double *range_delta,
double *output, int64_t *output_shape, const int64_t max_output_size, double *output, int64_t *output_shape, DynamicRangeErrorCode *error_code,
cudaStream_t cuda_stream); const int64_t max_output_size, cudaStream_t cuda_stream);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -19,8 +19,21 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
enum class DynamicRangeErrorCode {
kOk = 0,
kDeltaIsZero,
kInvalidPositiveDelta,
kInvalidNegativeDelta,
kMaxSizeExceeded
};
template <typename T>
void CudaValidateInputAndInferShape(const T *range_start, const T *range_end, const T *range_delta,
int64_t *output_shape, DynamicRangeErrorCode *error_code,
const int64_t max_output_size, cudaStream_t cuda_stream);
template <typename T> template <typename T>
void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape, void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
const int64_t max_output_size, cudaStream_t cuda_stream); DynamicRangeErrorCode *error_code, const int64_t max_output_size, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_

View File

@ -22,12 +22,12 @@ from mindspore import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
class RangeNet(nn.Cell): class RangeNet(nn.Cell):
def __init__(self): def __init__(self, maxlen=10000):
super(RangeNet, self).__init__() super(RangeNet, self).__init__()
self.range = P.Range() self.range = P.Range(maxlen)
def construct(self, s, e, d): def construct(self, start, limit, delta):
return self.range(s, e, d) return self.range(start, limit, delta)
@pytest.mark.level0 @pytest.mark.level0
@ -91,3 +91,27 @@ def test_range_invalid_max_output_length():
_ = P.Range(-1) _ = P.Range(-1)
_ = P.Range(None) _ = P.Range(None)
_ = P.Range('5') _ = P.Range('5')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_range_invalid_input():
with pytest.raises(RuntimeError) as info:
range_net = RangeNet(3500)
_ = range_net(Tensor(0, mstype.int32), Tensor(5, mstype.int32), Tensor(0, mstype.int32)).asnumpy()
assert "delta cannot be equal to zero" in str(info.value)
with pytest.raises(RuntimeError) as info:
range_net = RangeNet(2)
_ = range_net(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32)).asnumpy()
assert "number of elements in the output exceeds maxlen" in str(info.value)
with pytest.raises(RuntimeError) as info:
range_net = RangeNet(3500)
_ = range_net(Tensor(20, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32)).asnumpy()
assert "delta cannot be positive when limit < start" in str(info.value)
with pytest.raises(RuntimeError) as info:
range_net = RangeNet(3500)
_ = range_net(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(-4, mstype.int32)).asnumpy()
assert "delta cannot be negative when limit > start" in str(info.value)