diff --git a/docs/api/api_python/ops/mindspore.ops.TensorScatterUpdate.rst b/docs/api/api_python/ops/mindspore.ops.TensorScatterUpdate.rst index d0715db5500..8ca460a209c 100644 --- a/docs/api/api_python/ops/mindspore.ops.TensorScatterUpdate.rst +++ b/docs/api/api_python/ops/mindspore.ops.TensorScatterUpdate.rst @@ -20,4 +20,5 @@ 异常: - **TypeError** - `indices` 的数据类型既不是int32,也不是int64。 - **ValueError** - `input_x` 的shape长度小于 `indices` 的shape的最后一个维度。 - - **ValueError** - `input_x` 的值与输入 `indices` 不匹配。 \ No newline at end of file + - **ValueError** - `input_x` 的值与输入 `indices` 不匹配。 + - **RuntimeError** - `indices` 超出了 `input_x` 的索引范围。 \ No newline at end of file diff --git a/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_add.rst b/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_add.rst index dc2a477f766..5ec5d5f12f2 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_add.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_add.rst @@ -21,3 +21,4 @@ 异常: - **TypeError** - `indices` 的数据类型既不是int32,也不是int64。 - **ValueError** - `input_x` 的shape长度小于 `indices` 的shape的最后一个维度。 + - **RuntimeError** - `indices` 超出了 `input_x` 的索引范围。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_div.rst b/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_div.rst index 2d2d3c9c626..196234f4b36 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_div.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_div.rst @@ -22,3 +22,4 @@ mindspore.ops.tensor_scatter_div 异常: - **TypeError** - `indices` 的数据类型既不是int32,也不是int64。 - **ValueError** - `input_x` 的shape长度小于 `indices` 的shape的最后一个维度。 + - **RuntimeError** - `indices` 超出了 `input_x` 的索引范围。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_max.rst b/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_max.rst index f2d2bd53e69..581d692d0e8 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_max.rst 
+++ b/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_max.rst @@ -21,3 +21,4 @@ 异常: - **TypeError** - `indices` 的数据类型既不是int32,也不是int64。 - **ValueError** - `input_x` 的shape长度小于 `indices` 的shape的最后一个维度。 + - **RuntimeError** - `indices` 超出了 `input_x` 的索引范围。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_min.rst b/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_min.rst index 541a5fd19da..637a08a727e 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_min.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_min.rst @@ -21,3 +21,4 @@ 异常: - **TypeError** - `indices` 的数据类型既不是int32,也不是int64。 - **ValueError** - `input_x` 的shape长度小于 `indices` 的shape的最后一个维度。 + - **RuntimeError** - `indices` 超出了 `input_x` 的索引范围。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_mul.rst b/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_mul.rst index a4c56a88538..20c847ad6b4 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_mul.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_mul.rst @@ -21,3 +21,4 @@ mindspore.ops.tensor_scatter_mul 异常: - **TypeError** - `indices` 的数据类型既不是int32,也不是int64。 - **ValueError** - `input_x` 的shape长度小于 `indices` 的shape的最后一个维度。 + - **RuntimeError** - `indices` 超出了 `input_x` 的索引范围。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_sub.rst b/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_sub.rst index 4920812f854..b04ec487f77 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_sub.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_tensor_scatter_sub.rst @@ -21,3 +21,4 @@ 异常: - **TypeError** - `indices` 的数据类型既不是int32,也不是int64。 - **ValueError** - `input_x` 的shape长度小于 `indices` 的shape的最后一个维度。 + - **RuntimeError** - `indices` 超出了 `input_x` 的索引范围。 diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tensor_scatter_arithmetic_gpu_kernel.cc 
b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tensor_scatter_arithmetic_gpu_kernel.cc index 48d828117c2..ed8118c5400 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tensor_scatter_arithmetic_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tensor_scatter_arithmetic_gpu_kernel.cc @@ -190,7 +190,19 @@ template using Complex = mindspore::utils::Complex; template -void TensorScatterArithmeticGpuKernelMod::CheckIndicesValid(S *indices) { +void TensorScatterArithmeticGpuKernelMod::CheckIndicesValid(int *has_error, S *indices) { + // detect errors + int *has_error_host = reinterpret_cast(malloc(sizeof(int))); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(has_error_host, has_error, sizeof(int), cudaMemcpyDeviceToHost, + reinterpret_cast(stream_ptr_)), + "TensorScatterArithmeticGpuKernelMod cudaMemcpy failed in TensorScatterArithmeticGpuKernelMod::CheckIndicesValid."); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast(stream_ptr_)), + "cudaStreamSynchronized failed"); + if (has_error_host[0] != 1) { + return; + } + size_t total_indices_num = std::accumulate(indices_shape_.begin(), indices_shape_.end(), 1, std::multiplies()); size_t total_indices_bytes = total_indices_num * indices_unit_size_; @@ -242,7 +254,16 @@ bool TensorScatterArithmeticGpuKernelMod::LaunchKernel(const std::vector
(inputs, kIndex2); T *output = GetDeviceAddress(outputs, kIndex0); - (void)CheckIndicesValid(indices); + // set a flag to detect errors + int has_error_host[1] = {0}; + int *has_error = nullptr; + cudaMalloc(reinterpret_cast(&has_error), sizeof(int)); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(has_error, has_error_host, sizeof(int), cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr_)), + "TensorScatterArithmeticGpuKernelMod cudaMemcpy failed in TensorScatterArithmeticGpuKernelMod::LaunchKernel."); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast(stream_ptr_)), + "cudaStreamSynchronized failed"); if (!memcpy_flag_) { const size_t indices_len = indices_unit_size_ * vec_indices_stride_.size(); @@ -268,17 +289,20 @@ bool TensorScatterArithmeticGpuKernelMod::LaunchKernel(const std::vector
>) || (std::is_same_v>)) { if (kernel_name_ == kTensorScatterUpdate) { - CallTensorScatterUpdate(input, indices, update, output, block_size_, update_size_, output_size_, indices_dim_0_, - indices_dim_1_, reinterpret_cast(indices_stride_), + CallTensorScatterUpdate(input, indices, update, output, has_error, block_size_, update_size_, output_size_, + indices_dim_0_, indices_dim_1_, reinterpret_cast(indices_stride_), reinterpret_cast(work_shape_), device_id_, reinterpret_cast(stream_ptr_)); + + (void)CheckIndicesValid(has_error, indices); return true; } } else { - TensorScatterArithmetic(op_func_type_, input, indices, update, output, block_size_, update_size_, output_size_, - indices_dim_0_, indices_dim_1_, reinterpret_cast(indices_stride_), + TensorScatterArithmetic(op_func_type_, input, indices, update, output, has_error, block_size_, update_size_, + output_size_, indices_dim_0_, indices_dim_1_, reinterpret_cast(indices_stride_), reinterpret_cast(work_shape_), device_id_, reinterpret_cast(stream_ptr_)); + (void)CheckIndicesValid(has_error, indices); } return true; } diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tensor_scatter_arithmetic_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tensor_scatter_arithmetic_gpu_kernel.h index 5973064a679..e41897d00b0 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tensor_scatter_arithmetic_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tensor_scatter_arithmetic_gpu_kernel.h @@ -58,7 +58,7 @@ class TensorScatterArithmeticGpuKernelMod : public NativeGpuKernelMod, bool GetOpType(const BaseOperatorPtr &base_operator); void UpdateSize(); template - void CheckIndicesValid(S *indices); + void CheckIndicesValid(int *has_error, S *indices); template bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tensor_scatter_arithmetic.cu 
b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tensor_scatter_arithmetic.cu index 274fac71b13..7e71fe879cb 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tensor_scatter_arithmetic.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tensor_scatter_arithmetic.cu @@ -18,7 +18,7 @@ #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" template -__global__ void TensorScatterUpdateKernel(const T *input, const S *indices, const T *update, T *output, +__global__ void TensorScatterUpdateKernel(const T *input, const S *indices, const T *update, T *output, int *has_error, const size_t block_size, const size_t input_size, const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride, S *work_shape) { @@ -41,12 +41,14 @@ __global__ void TensorScatterUpdateKernel(const T *input, const S *indices, cons out_bound |= write_index >= output_size; if (!out_bound) { output[write_index] = update[read_index]; + } else { + has_error[0] = true; } } } template -__global__ void TensorScatterMinKernel(const T *input, const S *indices, const T *update, T *output, +__global__ void TensorScatterMinKernel(const T *input, const S *indices, const T *update, T *output, int *has_error, const size_t block_size, const size_t input_size, const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride, S *work_shape) { @@ -69,12 +71,14 @@ __global__ void TensorScatterMinKernel(const T *input, const S *indices, const T out_bound |= write_index >= output_size; if (!out_bound) { (void)MsAtomicMin(&output[write_index], update[read_index]); + } else { + has_error[0] = true; } } } template -__global__ void TensorScatterMaxKernel(const T *input, const S *indices, const T *update, T *output, +__global__ void TensorScatterMaxKernel(const T *input, const S *indices, const T *update, T *output, int *has_error, const size_t block_size, const size_t input_size, const size_t 
output_size, const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride, S *work_shape) { @@ -97,12 +101,14 @@ __global__ void TensorScatterMaxKernel(const T *input, const S *indices, const T out_bound |= write_index >= output_size; if (!out_bound) { (void)MsAtomicMax(&output[write_index], update[read_index]); + } else { + has_error[0] = true; } } } template -__global__ void TensorScatterAddKernel(const T *input, const S *indices, const T *update, T *output, +__global__ void TensorScatterAddKernel(const T *input, const S *indices, const T *update, T *output, int *has_error, const size_t block_size, const size_t input_size, const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride, S *work_shape) { @@ -125,12 +131,14 @@ __global__ void TensorScatterAddKernel(const T *input, const S *indices, const T out_bound |= write_index >= output_size; if (!out_bound) { (void)MsAtomicAdd(&output[write_index], update[read_index]); + } else { + has_error[0] = true; } } } template -__global__ void TensorScatterSubKernel(const T *input, const S *indices, const T *update, T *output, +__global__ void TensorScatterSubKernel(const T *input, const S *indices, const T *update, T *output, int *has_error, const size_t block_size, const size_t input_size, const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride, S *work_shape) { @@ -153,12 +161,14 @@ __global__ void TensorScatterSubKernel(const T *input, const S *indices, const T out_bound |= write_index >= output_size; if (!out_bound) { (void)MsAtomicSub(&output[write_index], update[read_index]); + } else { + has_error[0] = true; } } } template -__global__ void TensorScatterMulKernel(const T *input, const S *indices, const T *update, T *output, +__global__ void TensorScatterMulKernel(const T *input, const S *indices, const T *update, T *output, int *has_error, const size_t block_size, const size_t input_size, const size_t output_size, const 
size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride, S *work_shape) { @@ -181,12 +191,14 @@ __global__ void TensorScatterMulKernel(const T *input, const S *indices, const T out_bound |= write_index >= output_size; if (!out_bound) { (void)MsAtomicMul(&output[write_index], update[read_index]); + } else { + has_error[0] = true; } } } template -__global__ void TensorScatterDivKernel(const T *input, const S *indices, const T *update, T *output, +__global__ void TensorScatterDivKernel(const T *input, const S *indices, const T *update, T *output, int *has_error, const size_t block_size, const size_t input_size, const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1, S *indices_stride, S *work_shape) { @@ -209,44 +221,46 @@ __global__ void TensorScatterDivKernel(const T *input, const S *indices, const T out_bound |= write_index >= output_size; if (!out_bound) { (void)MsAtomicDiv(&output[write_index], update[read_index]); + } else { + has_error[0] = true; } } } template void TensorScatterArithmetic(const enum TensorScatterArithmeticFunctionType &func_type, const T *input, - const S *indices, const T *update, T *output, const size_t &block_size, + const S *indices, const T *update, T *output, int *has_error, const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride, S *work_shape, uint32_t device_id, cudaStream_t stream) { switch (func_type) { case TENSOR_SCATTER_FUNC_UPDATE: return TensorScatterUpdateKernel<<>>( - input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, + input, indices, update, output, has_error, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride, work_shape); case TENSOR_SCATTER_FUNC_MIN: return TensorScatterMinKernel<<>>( - input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, + input, indices, update, 
output, has_error, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride, work_shape); case TENSOR_SCATTER_FUNC_MAX: return TensorScatterMaxKernel<<>>( - input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, + input, indices, update, output, has_error, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride, work_shape); case TENSOR_SCATTER_FUNC_ADD: return TensorScatterAddKernel<<>>( - input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, + input, indices, update, output, has_error, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride, work_shape); case TENSOR_SCATTER_FUNC_SUB: return TensorScatterSubKernel<<>>( - input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, + input, indices, update, output, has_error, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride, work_shape); case TENSOR_SCATTER_FUNC_MUL: return TensorScatterMulKernel<<>>( - input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, + input, indices, update, output, has_error, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride, work_shape); case TENSOR_SCATTER_FUNC_DIV: return TensorScatterDivKernel<<>>( - input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, + input, indices, update, output, has_error, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride, work_shape); default: break; @@ -254,175 +268,178 @@ void TensorScatterArithmetic(const enum TensorScatterArithmeticFunctionType &fun } template -void CallTensorScatterUpdate(const T *input, const S *indices, const T *update, T *output, const size_t &block_size, - const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, - const size_t 
&indices_dim_1, S *indices_stride, S *work_shape, uint32_t device_id, - cudaStream_t stream) { +void CallTensorScatterUpdate(const T *input, const S *indices, const T *update, T *output, int *has_error, + const size_t &block_size, const size_t &input_size, const size_t &output_size, + const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride, S *work_shape, + uint32_t device_id, cudaStream_t stream) { TensorScatterUpdateKernel<<>>( - input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride, - work_shape); + input, indices, update, output, has_error, block_size, input_size, output_size, indices_dim_0, indices_dim_1, + indices_stride, work_shape); } template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const half *input, const int *indices, const half *update, - half *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, + half *output, int *has_error, const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const float *input, const int *indices, - const float *update, float *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, - const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, - cudaStream_t stream); + const float *update, float *output, int *has_error, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, + int *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const 
enum TensorScatterArithmeticFunctionType &func_type, const double *input, const int *indices, - const double *update, double *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, - const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, - cudaStream_t stream); + const double *update, double *output, int *has_error, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, + int *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const char *input, const int *indices, const char *update, - char *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, + char *output, int *has_error, const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const unsigned char *input, const int *indices, - const unsigned char *update, unsigned char *output, const size_t &block_size, const size_t &input_size, - const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, - int *work_shape, uint32_t device_id, cudaStream_t stream); + const unsigned char *update, unsigned char *output, int *has_error, const size_t &block_size, + const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, + int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const 
int16_t *input, const int *indices, - const int16_t *update, int16_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, - const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, - cudaStream_t stream); + const int16_t *update, int16_t *output, int *has_error, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, + int *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const uint16_t *input, const int *indices, - const uint16_t *update, uint16_t *output, const size_t &block_size, const size_t &input_size, + const uint16_t *update, uint16_t *output, int *has_error, const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const int *input, const int *indices, const int *update, - int *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, + int *output, int *has_error, const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const uint32_t *input, const int *indices, - const uint32_t *update, uint32_t *output, const size_t &block_size, const size_t &input_size, + const uint32_t *update, uint32_t *output, int *has_error, const size_t &block_size, const size_t &input_size, const size_t 
&output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const int64_t *input, const int *indices, - const int64_t *update, int64_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, - const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, - cudaStream_t stream); + const int64_t *update, int64_t *output, int *has_error, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, + int *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const uint64_t *input, const int *indices, - const uint64_t *update, uint64_t *output, const size_t &block_size, const size_t &input_size, + const uint64_t *update, uint64_t *output, int *has_error, const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const bool *input, const int *indices, const bool *update, - bool *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, + bool *output, int *has_error, const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType 
&func_type, const half *input, const int64_t *indices, - const half *update, half *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, - const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, - uint32_t device_id, cudaStream_t stream); - -template CUDA_LIB_EXPORT void TensorScatterArithmetic( - const enum TensorScatterArithmeticFunctionType &func_type, const float *input, const int64_t *indices, - const float *update, float *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, - const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, - uint32_t device_id, cudaStream_t stream); - -template CUDA_LIB_EXPORT void TensorScatterArithmetic( - const enum TensorScatterArithmeticFunctionType &func_type, const double *input, const int64_t *indices, - const double *update, double *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, - const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, - uint32_t device_id, cudaStream_t stream); - -template CUDA_LIB_EXPORT void TensorScatterArithmetic( - const enum TensorScatterArithmeticFunctionType &func_type, const char *input, const int64_t *indices, - const char *update, char *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, - const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, - uint32_t device_id, cudaStream_t stream); - -template CUDA_LIB_EXPORT void TensorScatterArithmetic( - const enum TensorScatterArithmeticFunctionType &func_type, const unsigned char *input, const int64_t *indices, - const unsigned char *update, unsigned char *output, const size_t &block_size, const size_t &input_size, + const half *update, half *output, int *has_error, const size_t &block_size, const size_t &input_size, const 
size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, uint32_t device_id, cudaStream_t stream); +template CUDA_LIB_EXPORT void TensorScatterArithmetic( + const enum TensorScatterArithmeticFunctionType &func_type, const float *input, const int64_t *indices, + const float *update, float *output, int *has_error, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, + int64_t *work_shape, uint32_t device_id, cudaStream_t stream); + +template CUDA_LIB_EXPORT void TensorScatterArithmetic( + const enum TensorScatterArithmeticFunctionType &func_type, const double *input, const int64_t *indices, + const double *update, double *output, int *has_error, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, + int64_t *work_shape, uint32_t device_id, cudaStream_t stream); + +template CUDA_LIB_EXPORT void TensorScatterArithmetic( + const enum TensorScatterArithmeticFunctionType &func_type, const char *input, const int64_t *indices, + const char *update, char *output, int *has_error, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, + int64_t *work_shape, uint32_t device_id, cudaStream_t stream); + +template CUDA_LIB_EXPORT void TensorScatterArithmetic( + const enum TensorScatterArithmeticFunctionType &func_type, const unsigned char *input, const int64_t *indices, + const unsigned char *update, unsigned char *output, int *has_error, const size_t &block_size, + const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, + int64_t *indices_stride, int64_t *work_shape, uint32_t device_id, cudaStream_t stream); + template CUDA_LIB_EXPORT 
void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const int16_t *input, const int64_t *indices, - const int16_t *update, int16_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, - const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, - uint32_t device_id, cudaStream_t stream); + const int16_t *update, int16_t *output, int *has_error, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, + int64_t *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const uint16_t *input, const int64_t *indices, - const uint16_t *update, uint16_t *output, const size_t &block_size, const size_t &input_size, + const uint16_t *update, uint16_t *output, int *has_error, const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const int *input, const int64_t *indices, - const int *update, int *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, - const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, - uint32_t device_id, cudaStream_t stream); + const int *update, int *output, int *has_error, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, + int64_t *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum 
TensorScatterArithmeticFunctionType &func_type, const uint32_t *input, const int64_t *indices, - const uint32_t *update, uint32_t *output, const size_t &block_size, const size_t &input_size, + const uint32_t *update, uint32_t *output, int *has_error, const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const int64_t *input, const int64_t *indices, - const int64_t *update, int64_t *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, - const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, - uint32_t device_id, cudaStream_t stream); + const int64_t *update, int64_t *output, int *has_error, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, + int64_t *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const uint64_t *input, const int64_t *indices, - const uint64_t *update, uint64_t *output, const size_t &block_size, const size_t &input_size, + const uint64_t *update, uint64_t *output, int *has_error, const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void TensorScatterArithmetic( const enum TensorScatterArithmeticFunctionType &func_type, const bool *input, const int64_t *indices, - const bool *update, bool *output, const size_t &block_size, const size_t &input_size, const size_t 
&output_size, - const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, - uint32_t device_id, cudaStream_t stream); + const bool *update, bool *output, int *has_error, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, + int64_t *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void CallTensorScatterUpdate<Complex<float>, int64_t>( const Complex<float> *input, const int64_t *indices, const Complex<float> *update, Complex<float> *output, - const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, - const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, uint32_t device_id, cudaStream_t stream); + int *has_error, const size_t &block_size, const size_t &input_size, const size_t &output_size, + const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, + uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void CallTensorScatterUpdate<Complex<float>, int>( - const Complex<float> *input, const int *indices, const Complex<float> *update, Complex<float> *output, + const Complex<float> *input, const int *indices, const Complex<float> *update, Complex<float> *output, int *has_error, const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void CallTensorScatterUpdate<Complex<double>, int64_t>( const Complex<double> *input, const int64_t *indices, const Complex<double> *update, Complex<double> *output, - const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, - const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, uint32_t device_id, cudaStream_t stream); + int *has_error, const size_t &block_size, const size_t &input_size, const size_t
&output_size, + const size_t &indices_dim_0, const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape, + uint32_t device_id, cudaStream_t stream); template CUDA_LIB_EXPORT void CallTensorScatterUpdate<Complex<double>, int>( const Complex<double> *input, const int *indices, const Complex<double> *update, Complex<double> *output, - const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, - const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, cudaStream_t stream); + int *has_error, const size_t &block_size, const size_t &input_size, const size_t &output_size, + const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, uint32_t device_id, + cudaStream_t stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tensor_scatter_arithmetic.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tensor_scatter_arithmetic.cuh index 3815c376a0e..e81e9fc13c9 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tensor_scatter_arithmetic.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/tensor_scatter_arithmetic.cuh @@ -31,14 +31,15 @@ enum TensorScatterArithmeticFunctionType { template <typename T, typename S> CUDA_LIB_EXPORT void TensorScatterArithmetic(const enum TensorScatterArithmeticFunctionType &func_type, const T *input, - const S *indices, const T *update, T *output, const size_t &block_size, - const size_t &input_size, const size_t &output_size, - const size_t &indices_dim_0, const size_t &indices_dim_1, - S *indices_stride, S *work_shape, uint32_t device_id, cudaStream_t stream); -template <typename T, typename S> -CUDA_LIB_EXPORT void CallTensorScatterUpdate(const T *input, const S *indices, const T *update, T *output, + const S *indices, const T *update, T *output, int *has_error, const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride, S
*work_shape, uint32_t device_id, cudaStream_t stream); +template <typename T, typename S> +CUDA_LIB_EXPORT void CallTensorScatterUpdate(const T *input, const S *indices, const T *update, T *output, + int *has_error, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, + const size_t &indices_dim_1, S *indices_stride, S *work_shape, + uint32_t device_id, cudaStream_t stream); #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TENSOR_SCATTER_ARITHMETIC_CUH_ diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/op/tensorscatteradd_tensorrt.cc b/mindspore/lite/src/extendrt/delegate/tensorrt/op/tensorscatteradd_tensorrt.cc index 414edd66d20..dff0835907a 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/op/tensorscatteradd_tensorrt.cc +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/op/tensorscatteradd_tensorrt.cc @@ -101,12 +101,18 @@ int TensorScatterAddPlugin::RunCudaTensorScatterAdd(const nvinfer1::PluginTensor cudaMalloc(&input_shape_dptr, input_dims.nbDims * sizeof(int)); cudaMemcpy(input_shape_dptr, input_dims.d, input_dims.nbDims * sizeof(int), cudaMemcpyHostToDevice); + // a flag for error detect + int flag_host[1] = {0}; + int *flag = nullptr; + cudaMalloc(reinterpret_cast<void **>(&flag), sizeof(int)); + cudaMemcpy(flag, flag_host, sizeof(int), cudaMemcpyHostToDevice); + cudaMemcpy(outputs[0], inputs[0], input_num * sizeof(float), cudaMemcpyDeviceToDevice); TensorScatterArithmetic(TensorScatterArithmeticFunctionType::TENSOR_SCATTER_FUNC_ADD, static_cast<const float *>(inputs[0]), static_cast<const int *>(inputs[1]), - static_cast<const float *>(inputs[INPUT_SIZE2]), static_cast<float *>(outputs[0]), block_size, - update_num, input_num, indice_dim_0, indice_dim_1, indice_stride_dptr, input_shape_dptr, - device_id_, stream); + static_cast<const float *>(inputs[INPUT_SIZE2]), static_cast<float *>(outputs[0]), flag, + block_size, update_num, input_num, indice_dim_0, indice_dim_1, indice_stride_dptr, + input_shape_dptr, device_id_, stream); cudaFree(indice_stride_dptr);
cudaFree(input_shape_dptr); diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index 6eac9feb9c4..9b65cdf07c8 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -3333,6 +3333,7 @@ def tensor_scatter_add(input_x, indices, updates): Raises: TypeError: If dtype of `indices` is neither int32 nor int64. ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`. + RuntimeError: If a value of `indices` is out of the index range of `input_x`. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -3385,6 +3386,7 @@ def tensor_scatter_sub(input_x, indices, updates): Raises: TypeError: If dtype of `indices` is neither int32 nor int64. ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`. + RuntimeError: If a value of `indices` is out of the index range of `input_x`. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -3432,6 +3434,7 @@ def tensor_scatter_max(input_x, indices, updates): Raises: TypeError: If dtype of `indices` is neither int32 nor int64. ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`. + RuntimeError: If a value of `indices` is out of the index range of `input_x`. Supported Platforms: ``GPU`` ``CPU`` @@ -3483,6 +3486,7 @@ def tensor_scatter_min(input_x, indices, updates): Raises: TypeError: If dtype of `indices` is neither int32 nor int64. ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`. + RuntimeError: If a value of `indices` is out of the index range of `input_x`. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -4590,6 +4594,7 @@ def tensor_scatter_mul(input_x, indices, updates): Raises: TypeError: If dtype of `indices` is neither int32 nor int64. ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`. + RuntimeError: If a value of `indices` is out of the index range of `input_x`.
Supported Platforms: ``GPU`` ``CPU`` @@ -4645,6 +4650,7 @@ def tensor_scatter_div(input_x, indices, updates): Raises: TypeError: If dtype of `indices` is neither int32 nor int64. ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`. + RuntimeError: If a value of `indices` is out of the index range of `input_x`. Supported Platforms: ``GPU`` ``CPU`` diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index 15b6229d6d0..c9400c7a2c4 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -6253,6 +6253,7 @@ class TensorScatterUpdate(_TensorScatterOp): TypeError: If dtype of `indices` is neither int32 nor int64. ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`. ValueError: If the value of `input_x` are not match with input `indices`. + RuntimeError: If a value of `indices` is out of the index range of `input_x`. Supported Platforms: ``Ascend`` ``GPU`` ``CPU``