From 347c9d7f0585094b22f1028a1cbde317ae7b5445 Mon Sep 17 00:00:00 2001 From: hw_hz Date: Tue, 7 Mar 2023 15:29:18 +0800 Subject: [PATCH] scatternd bug fixed. --- .../kernel/arrays/scatter_nd_gpu_kernel.cc | 52 ++++++++++--------- .../gpu/kernel/arrays/scatter_nd_gpu_kernel.h | 3 -- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_nd_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_nd_gpu_kernel.cc index e89ac47bdfc..cbea48c4c8c 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_nd_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_nd_gpu_kernel.cc @@ -72,45 +72,47 @@ bool ScatterNdGpuKernelMod::LaunchKernel(const std::vector &inputs, S *indices = GetDeviceAddress(inputs, 0); T *update = GetDeviceAddress(inputs, 1); T *output = GetDeviceAddress(outputs, 0); - S *indifes_stride = GetDeviceAddress(workspace, kIndex0); + S *indices_stride = GetDeviceAddress(workspace, kIndex0); S *work_shape = GetDeviceAddress(workspace, kIndex1); - if (!memcpy_flag_) { - const size_t indices_len = sizeof(S) * vec_indices_stride_.size(); - const size_t vec_work_len = sizeof(S) * attr_shape_.size(); - std::vector tmp_ind_stride; - (void)std::transform(vec_indices_stride_.begin(), vec_indices_stride_.end(), std::back_inserter(tmp_ind_stride), - [](size_t x) { return static_cast(x); }); - std::vector tmp_work_shape; - (void)std::transform(attr_shape_.begin(), attr_shape_.end(), std::back_inserter(tmp_work_shape), - [](int64_t x) { return static_cast(x); }); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(indifes_stride, &tmp_ind_stride[0], indices_len, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr_)), - "cudaMemcpy for indices_stride failed in " - "ScatterNdGpuKernelMod::LaunchKernel."); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( - cudaMemcpyAsync(work_shape, &tmp_work_shape[0], vec_work_len, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr_)), - "cudaMemcpy for work_shape failed in " - "ScatterNdGpuKernelMod::LaunchKernel."); - memcpy_flag_ = true; - } + const size_t indices_len = sizeof(S) * vec_indices_stride_.size(); + const size_t vec_work_len = sizeof(S) * attr_shape_.size(); + std::vector tmp_ind_stride; + (void)std::transform(vec_indices_stride_.begin(), vec_indices_stride_.end(), std::back_inserter(tmp_ind_stride), + [](size_t x) { return static_cast(x); }); + std::vector tmp_work_shape; + (void)std::transform(attr_shape_.begin(), attr_shape_.end(), std::back_inserter(tmp_work_shape), + [](int64_t x) { return static_cast(x); }); + + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(indices_stride, &tmp_ind_stride[0], indices_len, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr_)), + "cudaMemcpy for indices_stride failed in " + "ScatterNdGpuKernelMod::LaunchKernel."); + + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(work_shape, &tmp_work_shape[0], vec_work_len, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr_)), + "cudaMemcpy for work_shape failed in " + "ScatterNdGpuKernelMod::LaunchKernel."); CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( cudaMemsetAsync(output, static_cast(0.0), output_size_list_[0], reinterpret_cast(stream_ptr_)), "cudaMemSet failed in ScatterNdGpuKernelMod::LaunchKernel."); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast(stream_ptr_)), + "cudaStreamSynchronized failed"); + const size_t input_size = input_size_list_[kIndex1] / sizeof(T); const size_t output_size = output_size_list_[kIndex0] / sizeof(T); - ScatterNd(indices, update, output, block_size_, input_size, output_size, indices_dim_0_, indices_dim_1_, - indifes_stride, work_shape, reinterpret_cast(stream_ptr_)); + indices_stride, work_shape, reinterpret_cast(stream_ptr_)); return true; } bool ScatterNdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(base_operator); if (!MatchKernelFunc(base_operator, inputs, outputs)) { return false; } @@ -124,7 +126,7 @@ int ScatterNdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) { return ret; } - memcpy_flag_ = false; + if (!TryGetIntValue(inputs, kShapeIndex_, kernel_name_, &attr_shape_)) { MS_LOG(EXCEPTION) << "For " << kernel_name_ << "can't get shape input!"; return KRET_RESIZE_FAILED; diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_nd_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_nd_gpu_kernel.h index 164ced13595..bda283c737e 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_nd_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_nd_gpu_kernel.h @@ -64,9 +64,6 @@ class ScatterNdGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelpe size_t block_size_{1}; size_t indices_dim_0_{0}; size_t indices_dim_1_{0}; - - // memory in device - bool memcpy_flag_{false}; }; } // namespace kernel } // namespace mindspore