!49915 ScatterND GPU bug fix

Merge pull request !49915 from haozhang/fix_scatter_nd
This commit is contained in:
i-robot 2023-03-08 09:27:01 +00:00 committed by Gitee
commit 8d7a508099
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 27 additions and 28 deletions

View File

@ -72,45 +72,47 @@ bool ScatterNdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
S *indices = GetDeviceAddress<S>(inputs, 0);
T *update = GetDeviceAddress<T>(inputs, 1);
T *output = GetDeviceAddress<T>(outputs, 0);
S *indifes_stride = GetDeviceAddress<S>(workspace, kIndex0);
S *indices_stride = GetDeviceAddress<S>(workspace, kIndex0);
S *work_shape = GetDeviceAddress<S>(workspace, kIndex1);
if (!memcpy_flag_) {
const size_t indices_len = sizeof(S) * vec_indices_stride_.size();
const size_t vec_work_len = sizeof(S) * attr_shape_.size();
std::vector<S> tmp_ind_stride;
(void)std::transform(vec_indices_stride_.begin(), vec_indices_stride_.end(), std::back_inserter(tmp_ind_stride),
[](size_t x) { return static_cast<S>(x); });
std::vector<S> tmp_work_shape;
(void)std::transform(attr_shape_.begin(), attr_shape_.end(), std::back_inserter(tmp_work_shape),
[](int64_t x) { return static_cast<S>(x); });
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(indifes_stride, &tmp_ind_stride[0], indices_len, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemcpy for indices_stride failed in "
"ScatterNdGpuKernelMod::LaunchKernel.");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(work_shape, &tmp_work_shape[0], vec_work_len, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemcpy for work_shape failed in "
"ScatterNdGpuKernelMod::LaunchKernel.");
memcpy_flag_ = true;
}
const size_t indices_len = sizeof(S) * vec_indices_stride_.size();
const size_t vec_work_len = sizeof(S) * attr_shape_.size();
std::vector<S> tmp_ind_stride;
(void)std::transform(vec_indices_stride_.begin(), vec_indices_stride_.end(), std::back_inserter(tmp_ind_stride),
[](size_t x) { return static_cast<S>(x); });
std::vector<S> tmp_work_shape;
(void)std::transform(attr_shape_.begin(), attr_shape_.end(), std::back_inserter(tmp_work_shape),
[](int64_t x) { return static_cast<S>(x); });
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(indices_stride, &tmp_ind_stride[0], indices_len, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemcpy for indices_stride failed in "
"ScatterNdGpuKernelMod::LaunchKernel.");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(work_shape, &tmp_work_shape[0], vec_work_len, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemcpy for work_shape failed in "
"ScatterNdGpuKernelMod::LaunchKernel.");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemsetAsync(output, static_cast<T>(0.0), output_size_list_[0], reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemSet failed in ScatterNdGpuKernelMod::LaunchKernel.");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaStreamSynchronized failed");
const size_t input_size = input_size_list_[kIndex1] / sizeof(T);
const size_t output_size = output_size_list_[kIndex0] / sizeof(T);
ScatterNd(indices, update, output, block_size_, input_size, output_size, indices_dim_0_, indices_dim_1_,
indifes_stride, work_shape, reinterpret_cast<cudaStream_t>(stream_ptr_));
indices_stride, work_shape, reinterpret_cast<cudaStream_t>(stream_ptr_));
return true;
}
bool ScatterNdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
@ -124,7 +126,7 @@ int ScatterNdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
memcpy_flag_ = false;
if (!TryGetIntValue(inputs, kShapeIndex_, kernel_name_, &attr_shape_)) {
MS_LOG(EXCEPTION) << "For " << kernel_name_ << "can't get shape input!";
return KRET_RESIZE_FAILED;

View File

@ -64,9 +64,6 @@ class ScatterNdGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelpe
size_t block_size_{1};
size_t indices_dim_0_{0};
size_t indices_dim_1_{0};
// memory in device
bool memcpy_flag_{false};
};
} // namespace kernel
} // namespace mindspore