From ec4f7e7691dd79eabb3e2128a34b357f04a164c3 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 2 Feb 2023 23:14:11 +0800 Subject: [PATCH] sparsedenseadd sparse tensor add 2 3 4 6 7 8 --- .../cuda_ops/sparse_tensor_dense_add_impl.cu | 2 -- .../sparse_tensor_dense_add_gpu_kernel.cc | 33 +++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_tensor_dense_add_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_tensor_dense_add_impl.cu index 146a4d5557b..e93592786a7 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_tensor_dense_add_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_tensor_dense_add_impl.cu @@ -42,8 +42,6 @@ __global__ void SparseTensorDenseAddKernelFunc(size_t input_elements, size_t ran int out_index = 0; for (size_t j = 0; j < rank; j++) { int index = x1_indices_addr[pos * rank + j]; - CUDA_KERNEL_ASSERT(x2_shape[j] == x1_shape_addr[j] && "The input x1_shape does not equal x2_shape!"); - CUDA_KERNEL_ASSERT(index < x1_shape_addr[j] && "The input x1_indices is out of bounds!"); int count = 1; for (size_t k = j + 1; k < rank; k++) { count *= x1_shape_addr[k]; diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_tensor_dense_add_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_tensor_dense_add_gpu_kernel.cc index 25755e3a25e..b9249ca8ab3 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_tensor_dense_add_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_tensor_dense_add_gpu_kernel.cc @@ -203,6 +203,39 @@ bool SparseTensorDenseAddGpuKernelMod::LaunchKernel(const std::vector(cuda_stream_)), "cudaMemcpyAsync x2_shape failed"); + constexpr int X1_SHAPE_INDICES = 2; + std::vector x1_shape(inputs[X1_SHAPE_INDICES]->size / sizeof(I)); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(x1_shape.data(), x1_shape_addr, inputs[X1_SHAPE_INDICES]->size, cudaMemcpyDeviceToHost, + reinterpret_cast(cuda_stream_)), + "cudaMemcpyAsync x1_shape failed"); + + std::vector x1_indices_host(inputs[0]->size / sizeof(I)); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(x1_indices_host.data(), x1_indices_addr, inputs[0]->size, cudaMemcpyDeviceToHost, + reinterpret_cast(cuda_stream_)), + "cudaMemcpyAsync x1_indices failed"); + + if (x1_shape.size() != x2_shape_.size()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << " The input x1_shape size does not equal x2_shape size! " + << "tensor shape of 'sparse': " << x1_shape.size() + << ",and the tensor shape of 'dense':" << x2_shape_.size(); + return false; + } + + for (size_t idx = 0; idx < x2_shape_.size(); ++idx) { + if (x1_shape[idx] != x2_shape_[idx]) { + MS_LOG(ERROR) << "For '" << kernel_name_ << " The input x1_shape dim does not equal x2_shape dim! " + << "tensor dim of 'sparse': " << x1_shape[idx] + << ",and the tensor dim of 'dense':" << x2_shape_[idx]; + return false; + } + if (x1_indices_host[idx] >= x1_shape[idx]) { + MS_LOG(ERROR) << "For '" << kernel_name_ << " The input x1_indices is out of bounds! " + << "x1_indices is : " << x1_indices_host[idx] << ", tensor bounds is:" << x1_shape[idx]; + return false; + } + } CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( cudaMemcpyAsync(y_addr, x2_values_addr, output_elements_ * sizeof(T), cudaMemcpyDeviceToDevice, reinterpret_cast(cuda_stream_)),