From 799cb79873fdb7d87add9e0ad41e99d9b5afa918 Mon Sep 17 00:00:00 2001
From: tom__chen
Date: Tue, 18 May 2021 15:40:25 -0400
Subject: [PATCH] remove synchronous error check

The IndexAdd GPU kernel no longer launches a separate index-validation
kernel, copies an error code back to the host and calls
cudaDeviceSynchronize() on every launch. Instead, the IndexAdd and
IndexAddAtomic kernels check each destination index in place and skip
out-of-range values. The ST test that expected a RuntimeError for an
out-of-range index is removed, since the op no longer raises in that
case.

---
 .../gpu/cuda_impl/index_add_impl.cu   | 41 +++++--------------
 .../gpu/cuda_impl/index_add_impl.cuh  |  9 ----
 .../gpu/math/index_add_gpu_kernel.h   | 30 +-------------
 tests/st/ops/gpu/test_index_add_op.py |  7 ----
 4 files changed, 12 insertions(+), 75 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cu
index 16a0f9a6af8..6d62951d775 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cu
@@ -13,27 +13,11 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
+#include
 #include "backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh"
 #include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 #include "include/cuda_fp16.h"
-__global__ void InitErrorCode(IndexAddErrorCode *error_code) {
-  *error_code = IndexAddErrorCode::kOk;
-}
-
-__global__ void ValidateIndexValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
-                                    IndexAddErrorCode *error_code) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < src_axis_size; pos += blockDim.x * gridDim.x) {
-    const int idx_value = index[pos];
-    if (idx_value < 0 || idx_value >= dst_axis_size) {
-      *error_code = IndexAddErrorCode::kIndexOutOfRange;
-      return;
-    }
-  }
-  return;
-}
-
 template <typename T>
 __global__ void IndexAddAtomic(T *dst, const int *index, const T *src, const size_t src_size, const size_t outer_size,
                                const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size) {
@@ -41,9 +25,11 @@ __global__ void IndexAddAtomic(T *dst, const int *index, const T *src, const siz
     const size_t src_axis_idx = (pos / inner_size) % src_axis_size;
     const size_t src_outer_idx = pos / (src_axis_size * inner_size);
     const size_t dst_axis_idx = static_cast<size_t>(index[src_axis_idx]);
-    const size_t dst_inner_idx = pos % inner_size;
-    const size_t dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx;
-    MsAtomicAdd(&dst[dst_idx], src[pos]);
+    if (dst_axis_idx >= 0 && dst_axis_idx < dst_axis_size) {
+      const size_t dst_inner_idx = pos % inner_size;
+      const size_t dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx;
+      MsAtomicAdd(&dst[dst_idx], src[pos]);
+    }
   }
   return;
 }
@@ -55,20 +41,15 @@ __global__ void IndexAdd(T *dst, const int *index, const T *src, const size_t sr
     const size_t src_axis_idx = (pos / inner_size) % src_axis_size;
     const size_t src_outer_idx = pos / (src_axis_size * inner_size);
     const size_t dst_axis_idx = static_cast<size_t>(index[src_axis_idx]);
-    const size_t dst_inner_idx = pos % inner_size;
-    const size_t dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx;
-    dst[dst_idx] += src[pos];
+    if (dst_axis_idx >= 0 && dst_axis_idx < dst_axis_size) {
+      const size_t dst_inner_idx = pos % inner_size;
+      const size_t dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx;
+      dst[dst_idx] += src[pos];
+    }
   }
   return;
 }
 
-void ValidateIndexAddInputValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
-                                 IndexAddErrorCode *error_code, cudaStream_t cuda_stream) {
-  InitErrorCode<<<1, 1, 0, cuda_stream>>>(error_code);
-  ValidateIndexValues<<<GET_BLOCKS(src_axis_size), GET_THREADS, 0, cuda_stream>>>(index, src_axis_size, dst_axis_size,
-                                                                                  error_code);
-}
-
 template <typename T>
 void CalIndexAdd(T *dst, const int *index, const T *src, const size_t outer_size, const size_t src_axis_size,
                  const size_t dst_axis_size, const size_t inner_size, const bool use_lock, cudaStream_t cuda_stream) {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh
index a32adaeafec..ab993c93603 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh
@@ -16,16 +16,7 @@
 #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_
 
-enum class IndexAddErrorCode {
-  kOk = 0,
-  kIndexOutOfRange
-};
-
-void ValidateIndexAddInputValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
-                                 IndexAddErrorCode *error_code, cudaStream_t cuda_stream);
-
 template <typename T>
 void CalIndexAdd(T *dst, const int *index, const T *src, const size_t outer_size, const size_t src_axis_size,
                  const size_t dst_axis_size, const size_t inner_size, const bool use_lock, cudaStream_t cuda_stream);
-
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/index_add_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/index_add_gpu_kernel.h
index 5ceb1b1822a..c8119d1d63e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/index_add_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/index_add_gpu_kernel.h
@@ -35,8 +35,7 @@ class IndexAddGpuKernel : public GpuKernel {
       src_axis_size_(0),
       dst_axis_size_(0),
       inner_size_(0),
-      use_lock_(true),
-      check_index_bound_(true) {}
+      use_lock_(true) {}
   ~IndexAddGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -49,19 +48,6 @@ class IndexAddGpuKernel : public GpuKernel {
     int *index = GetDeviceAddress<int>(inputs, 1);
     T *src = GetDeviceAddress<T>(inputs, 2);
     T *dst_out = GetDeviceAddress<T>(outputs, 0);
-
-    if (check_index_bound_) {
-      IndexAddErrorCode *error_code_addr = GetDeviceAddress<IndexAddErrorCode>(workspace, 0);
-      IndexAddErrorCode error_code = IndexAddErrorCode::kOk;
-      ValidateIndexAddInputValues(index, src_axis_size_, dst_axis_size_, error_code_addr,
-                                  reinterpret_cast<cudaStream_t>(stream_ptr));
-      CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
-                                cudaMemcpyAsync(&error_code, error_code_addr, sizeof(IndexAddErrorCode),
-                                                cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
-                                "Failed to copy error code to host.");
-      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
-      LogExceptionIfNotOk(error_code);
-    }
     CalIndexAdd(dst, index, src, outer_size_, src_axis_size_, dst_axis_size_, inner_size_, use_lock_,
                 reinterpret_cast<cudaStream_t>(stream_ptr));
     CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
@@ -119,22 +105,9 @@ class IndexAddGpuKernel : public GpuKernel {
     input_size_list_.push_back(index_size_);
     input_size_list_.push_back(src_size_);
     output_size_list_.push_back(output_size_);
-    workspace_size_list_.push_back(sizeof(IndexAddErrorCode));
   }
 
  private:
-  void LogExceptionIfNotOk(IndexAddErrorCode error_code) {
-    switch (error_code) {
-      case IndexAddErrorCode::kOk:
-        return;
-      case IndexAddErrorCode::kIndexOutOfRange:
-        MS_LOG(EXCEPTION) << "gpu IndexAdd op error: values of index tensor is out of range";
range"; - break; - default: - MS_LOG(EXCEPTION) << "gpu IndexAdd op unknown error"; - } - } - size_t dst_size_; size_t index_size_; size_t src_size_; @@ -144,7 +117,6 @@ class IndexAddGpuKernel : public GpuKernel { size_t dst_axis_size_; size_t inner_size_; bool use_lock_; - bool check_index_bound_; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; diff --git a/tests/st/ops/gpu/test_index_add_op.py b/tests/st/ops/gpu/test_index_add_op.py index c3145d0a977..8cf53273f5c 100644 --- a/tests/st/ops/gpu/test_index_add_op.py +++ b/tests/st/ops/gpu/test_index_add_op.py @@ -255,13 +255,6 @@ def test_index_add_invalid_inputs(): net = NetIndexAdd(x, 1) _ = net(Tensor(idx), Tensor(y)) - with pytest.raises(RuntimeError) as info: - #index value not in the range of 0 to len(x[axis]) - idx = np.array([5, 6]).astype(np.int32) - net = NetIndexAdd(x, 1) - _ = net(Tensor(idx), Tensor(y)) - assert "out of range" in str(info.value) - class IndexAddGradNet(nn.Cell): def __init__(self, network):