!11785 gpu RangeOp better error handling in kernel
From: @peilin-wang Reviewed-by: @robingrosman,@tom__chen Signed-off-by: @tom__chen
This commit is contained in:
commit
96f007ebb4
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -44,22 +44,58 @@ class DynamicRangeGpuKernel : public GpuKernel {
|
|||
T *range_delta = GetDeviceAddress<T>(inputs, 2);
|
||||
T *output_device_address = GetDeviceAddress<T>(outputs, 0);
|
||||
int64_t *output_shape_device_address = GetDeviceAddress<int64_t>(workspace, 0);
|
||||
DynamicRangeErrorCode *error_code_device_address = GetDeviceAddress<DynamicRangeErrorCode>(workspace, 1);
|
||||
|
||||
stream_ptr_ = stream_ptr;
|
||||
|
||||
CalRange(range_start, range_end, range_delta, output_device_address, output_shape_device_address,
|
||||
max_output_length_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CudaValidateInputAndInferShape(range_start, range_end, range_delta, output_shape_device_address,
|
||||
error_code_device_address, max_output_length_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
DynamicRangeErrorCode error_code = DynamicRangeErrorCode::kOk;
|
||||
|
||||
CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_,
|
||||
cudaMemcpyAsync(&error_code, error_code_device_address, sizeof(DynamicRangeErrorCode),
|
||||
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Failed to copy error code to host.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
|
||||
|
||||
// use workspace[0] for actual output shape, we know it must be 1d
|
||||
CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_,
|
||||
cudaMemcpyAsync(&output_shape_, output_shape_device_address, sizeof(int64_t),
|
||||
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Failed to copy gpu memory.");
|
||||
"Failed to copy output_shape to host.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
|
||||
|
||||
LogExceptionIfNotOk(error_code);
|
||||
|
||||
CalRange(range_start, range_end, range_delta, output_device_address, output_shape_device_address,
|
||||
error_code_device_address, max_output_length_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void LogExceptionIfNotOk(DynamicRangeErrorCode error_code) {
|
||||
switch (error_code) {
|
||||
case DynamicRangeErrorCode::kOk:
|
||||
return;
|
||||
case DynamicRangeErrorCode::kDeltaIsZero:
|
||||
MS_LOG(EXCEPTION) << "gpu RangeOp input error: delta cannot be equal to zero";
|
||||
break;
|
||||
case DynamicRangeErrorCode::kInvalidPositiveDelta:
|
||||
MS_LOG(EXCEPTION) << "gpu RangeOp input error: delta cannot be positive when limit < start";
|
||||
break;
|
||||
case DynamicRangeErrorCode::kInvalidNegativeDelta:
|
||||
MS_LOG(EXCEPTION) << "gpu RangeOp input error: delta cannot be negative when limit > start";
|
||||
break;
|
||||
case DynamicRangeErrorCode::kMaxSizeExceeded:
|
||||
MS_LOG(EXCEPTION) << "gpu RangeOp memory error: the number of elements in the output exceeds maxlen";
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "gpu RangeOp unknown error";
|
||||
}
|
||||
}
|
||||
|
||||
void PostExecute() override {
|
||||
// required synchronize for PostExecute
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
|
||||
|
@ -103,6 +139,7 @@ class DynamicRangeGpuKernel : public GpuKernel {
|
|||
|
||||
// this op outputs a 1d tensor, size of one int64_t is enough space to hold the shape.
|
||||
workspace_size_list_.push_back(sizeof(int64_t));
|
||||
workspace_size_list_.push_back(sizeof(DynamicRangeErrorCode));
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -20,57 +20,90 @@
|
|||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
__device__ void CheckInputs(const T &start, const T &end, const T &delta) {
|
||||
__global__ void ValidateInputAndInferShape(const T *range_start, const T *range_end, const T *range_delta,
|
||||
int64_t *output_shape, DynamicRangeErrorCode *error_code,
|
||||
const int64_t max_output_size) {
|
||||
T start = range_start[0];
|
||||
T end = range_end[0];
|
||||
T delta = range_delta[0];
|
||||
*error_code = DynamicRangeErrorCode::kOk;
|
||||
|
||||
if (delta == 0) {
|
||||
asm("trap;");
|
||||
*error_code = DynamicRangeErrorCode::kDeltaIsZero;
|
||||
return;
|
||||
}
|
||||
|
||||
if (start < end && delta < 0) {
|
||||
asm("trap;");
|
||||
*error_code = DynamicRangeErrorCode::kInvalidNegativeDelta;
|
||||
return;
|
||||
}
|
||||
|
||||
if (start > end && delta > 0) {
|
||||
asm("trap;");
|
||||
*error_code = DynamicRangeErrorCode::kInvalidPositiveDelta;
|
||||
return;
|
||||
}
|
||||
|
||||
if (*error_code == DynamicRangeErrorCode::kOk) {
|
||||
int64_t real_output_shape = static_cast<int64_t>(ceil(static_cast<double>(end - start) / delta));
|
||||
if (real_output_shape > max_output_size) {
|
||||
*error_code = DynamicRangeErrorCode::kMaxSizeExceeded;
|
||||
}
|
||||
*output_shape = real_output_shape;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void Range(const T *range_start, const T *range_end, const T *range_delta, T *output,
|
||||
int64_t *output_shape, const int64_t max_output_size) {
|
||||
__global__ void Range(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
|
||||
const int64_t max_output_size) {
|
||||
T start = range_start[0];
|
||||
T end = range_end[0];
|
||||
T delta = range_delta[0];
|
||||
|
||||
CheckInputs(start, end, delta);
|
||||
|
||||
int64_t real_output_shape = static_cast<int64_t>(ceil(static_cast<double>(end - start) / delta));
|
||||
if (real_output_shape > max_output_size) {
|
||||
asm("trap;");
|
||||
}
|
||||
*output_shape = real_output_shape;
|
||||
|
||||
size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (; gt_id < real_output_shape; gt_id += blockDim.x * gridDim.x) {
|
||||
for (; gt_id < *output_shape; gt_id += blockDim.x * gridDim.x) {
|
||||
output[gt_id] = gt_id * delta + start;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CudaValidateInputAndInferShape(const T *range_start, const T *range_end, const T *range_delta,
|
||||
int64_t *output_shape, DynamicRangeErrorCode *error_code,
|
||||
const int64_t max_output_size, cudaStream_t cuda_stream) {
|
||||
ValidateInputAndInferShape<<<1, 1, 0, cuda_stream>>>(range_start, range_end, range_delta, output_shape, error_code,
|
||||
max_output_size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
|
||||
const int64_t max_output_size, cudaStream_t cuda_stream) {
|
||||
DynamicRangeErrorCode *error_code, const int64_t max_output_size, cudaStream_t cuda_stream) {
|
||||
Range<<<GET_BLOCKS(max_output_size), GET_THREADS, 0, cuda_stream>>>(range_start, range_end, range_delta,
|
||||
output, output_shape, max_output_size);
|
||||
}
|
||||
|
||||
template void CudaValidateInputAndInferShape<int>(const int *range_start, const int *range_end, const int *range_delta,
|
||||
int64_t *output_shape, DynamicRangeErrorCode *error_code,
|
||||
const int64_t max_output_size, cudaStream_t cuda_stream);
|
||||
template void CudaValidateInputAndInferShape<int64_t>(const int64_t *range_start, const int64_t *range_end,
|
||||
const int64_t *range_delta, int64_t *output_shape,
|
||||
DynamicRangeErrorCode *error_code, const int64_t max_output_size,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CudaValidateInputAndInferShape<float>(const float *range_start, const float *range_end,
|
||||
const float *range_delta, int64_t *output_shape,
|
||||
DynamicRangeErrorCode *error_code, const int64_t max_output_size,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CudaValidateInputAndInferShape<double>(const double *range_start, const double *range_end,
|
||||
const double *range_delta, int64_t *output_shape,
|
||||
DynamicRangeErrorCode *error_code, const int64_t max_output_size,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
template void CalRange<int>(const int *range_start, const int *range_end, const int *range_delta, int *output,
|
||||
int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream);
|
||||
|
||||
int64_t *output_shape, DynamicRangeErrorCode *error_code, const int64_t max_output_size,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalRange<int64_t>(const int64_t *range_start, const int64_t *range_end, const int64_t *range_delta,
|
||||
int64_t *output, int64_t *output_shape, const int64_t max_output_size,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
int64_t *output, int64_t *output_shape, DynamicRangeErrorCode *error_code,
|
||||
const int64_t max_output_size, cudaStream_t cuda_stream);
|
||||
template void CalRange<float>(const float *range_start, const float *range_end, const float *range_delta, float *output,
|
||||
int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream);
|
||||
int64_t *output_shape, DynamicRangeErrorCode *error_code, const int64_t max_output_size,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalRange<double>(const double *range_start, const double *range_end, const double *range_delta,
|
||||
double *output, int64_t *output_shape, const int64_t max_output_size,
|
||||
cudaStream_t cuda_stream);
|
||||
double *output, int64_t *output_shape, DynamicRangeErrorCode *error_code,
|
||||
const int64_t max_output_size, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -19,8 +19,21 @@
|
|||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
enum class DynamicRangeErrorCode {
|
||||
kOk = 0,
|
||||
kDeltaIsZero,
|
||||
kInvalidPositiveDelta,
|
||||
kInvalidNegativeDelta,
|
||||
kMaxSizeExceeded
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void CudaValidateInputAndInferShape(const T *range_start, const T *range_end, const T *range_delta,
|
||||
int64_t *output_shape, DynamicRangeErrorCode *error_code,
|
||||
const int64_t max_output_size, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
|
||||
const int64_t max_output_size, cudaStream_t cuda_stream);
|
||||
DynamicRangeErrorCode *error_code, const int64_t max_output_size, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_
|
||||
|
|
|
@ -22,12 +22,12 @@ from mindspore import Tensor
|
|||
from mindspore.ops import operations as P
|
||||
|
||||
class RangeNet(nn.Cell):
|
||||
def __init__(self):
|
||||
def __init__(self, maxlen=10000):
|
||||
super(RangeNet, self).__init__()
|
||||
self.range = P.Range()
|
||||
self.range = P.Range(maxlen)
|
||||
|
||||
def construct(self, s, e, d):
|
||||
return self.range(s, e, d)
|
||||
def construct(self, start, limit, delta):
|
||||
return self.range(start, limit, delta)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -91,3 +91,27 @@ def test_range_invalid_max_output_length():
|
|||
_ = P.Range(-1)
|
||||
_ = P.Range(None)
|
||||
_ = P.Range('5')
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_range_invalid_input():
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
range_net = RangeNet(3500)
|
||||
_ = range_net(Tensor(0, mstype.int32), Tensor(5, mstype.int32), Tensor(0, mstype.int32)).asnumpy()
|
||||
assert "delta cannot be equal to zero" in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
range_net = RangeNet(2)
|
||||
_ = range_net(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32)).asnumpy()
|
||||
assert "number of elements in the output exceeds maxlen" in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
range_net = RangeNet(3500)
|
||||
_ = range_net(Tensor(20, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32)).asnumpy()
|
||||
assert "delta cannot be positive when limit < start" in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
range_net = RangeNet(3500)
|
||||
_ = range_net(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(-4, mstype.int32)).asnumpy()
|
||||
assert "delta cannot be negative when limit > start" in str(info.value)
|
||||
|
|
Loading…
Reference in New Issue