forked from mindspore-Ecosystem/mindspore
!49915 ScatterND GPU bug fix
Merge pull request !49915 from haozhang/fix_scatter_nd
commit 8d7a508099
@@ -72,45 +72,47 @@ bool ScatterNdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
   S *indices = GetDeviceAddress<S>(inputs, 0);
   T *update = GetDeviceAddress<T>(inputs, 1);
   T *output = GetDeviceAddress<T>(outputs, 0);
-  S *indifes_stride = GetDeviceAddress<S>(workspace, kIndex0);
+  S *indices_stride = GetDeviceAddress<S>(workspace, kIndex0);
   S *work_shape = GetDeviceAddress<S>(workspace, kIndex1);
 
-  if (!memcpy_flag_) {
   const size_t indices_len = sizeof(S) * vec_indices_stride_.size();
   const size_t vec_work_len = sizeof(S) * attr_shape_.size();
   std::vector<S> tmp_ind_stride;
   (void)std::transform(vec_indices_stride_.begin(), vec_indices_stride_.end(), std::back_inserter(tmp_ind_stride),
                        [](size_t x) { return static_cast<S>(x); });
   std::vector<S> tmp_work_shape;
   (void)std::transform(attr_shape_.begin(), attr_shape_.end(), std::back_inserter(tmp_work_shape),
                        [](int64_t x) { return static_cast<S>(x); });
+
   CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-    cudaMemcpyAsync(indifes_stride, &tmp_ind_stride[0], indices_len, cudaMemcpyHostToDevice,
+    cudaMemcpyAsync(indices_stride, &tmp_ind_stride[0], indices_len, cudaMemcpyHostToDevice,
                     reinterpret_cast<cudaStream_t>(stream_ptr_)),
     "cudaMemcpy for indices_stride failed in "
     "ScatterNdGpuKernelMod::LaunchKernel.");
   CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
     cudaMemcpyAsync(work_shape, &tmp_work_shape[0], vec_work_len, cudaMemcpyHostToDevice,
                     reinterpret_cast<cudaStream_t>(stream_ptr_)),
     "cudaMemcpy for work_shape failed in "
     "ScatterNdGpuKernelMod::LaunchKernel.");
-    memcpy_flag_ = true;
-  }
 
   CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
     cudaMemsetAsync(output, static_cast<T>(0.0), output_size_list_[0], reinterpret_cast<cudaStream_t>(stream_ptr_)),
     "cudaMemSet failed in ScatterNdGpuKernelMod::LaunchKernel.");
 
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
+                                     "cudaStreamSynchronized failed");
+
   const size_t input_size = input_size_list_[kIndex1] / sizeof(T);
   const size_t output_size = output_size_list_[kIndex0] / sizeof(T);
 
   ScatterNd(indices, update, output, block_size_, input_size, output_size, indices_dim_0_, indices_dim_1_,
-            indifes_stride, work_shape, reinterpret_cast<cudaStream_t>(stream_ptr_));
+            indices_stride, work_shape, reinterpret_cast<cudaStream_t>(stream_ptr_));
   return true;
 }
 
 bool ScatterNdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                  const std::vector<KernelTensorPtr> &outputs) {
+  MS_EXCEPTION_IF_NULL(base_operator);
   if (!MatchKernelFunc(base_operator, inputs, outputs)) {
     return false;
   }
@@ -124,7 +126,7 @@ int ScatterNdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
   if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
     return ret;
   }
-  memcpy_flag_ = false;
+
   if (!TryGetIntValue(inputs, kShapeIndex_, kernel_name_, &attr_shape_)) {
     MS_LOG(EXCEPTION) << "For " << kernel_name_ << "can't get shape input!";
     return KRET_RESIZE_FAILED;
@@ -64,9 +64,6 @@ class ScatterNdGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelpe
   size_t block_size_{1};
   size_t indices_dim_0_{0};
   size_t indices_dim_1_{0};
-
-  // memory in device
-  bool memcpy_flag_{false};
 };
 } // namespace kernel
 } // namespace mindspore
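For readers unfamiliar with the staging pattern touched by this diff, below is a minimal standalone sketch (not part of the commit) of what the updated LaunchKernel does: a small host-side stride vector is copied into a device workspace buffer with cudaMemcpyAsync, and the stream is synchronized before the host temporary is reused or destroyed, mirroring the cudaStreamSynchronize call added by the fix. The names and values here (CHECK_CUDA, d_indices_stride, the {4, 1} strides) are illustrative only and do not appear in the commit.

#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Hypothetical error-checking helper for this sketch (the real kernel uses
// CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE).
#define CHECK_CUDA(expr)                                        \
  do {                                                          \
    cudaError_t err = (expr);                                   \
    if (err != cudaSuccess) {                                   \
      std::printf("CUDA error: %s\n", cudaGetErrorString(err)); \
      return 1;                                                 \
    }                                                           \
  } while (0)

int main() {
  cudaStream_t stream;
  CHECK_CUDA(cudaStreamCreate(&stream));

  // Host-side stride vector, analogous to tmp_ind_stride in LaunchKernel.
  std::vector<int> tmp_ind_stride = {4, 1};
  const size_t indices_len = sizeof(int) * tmp_ind_stride.size();

  // Device buffer, analogous to the indices_stride workspace address.
  int *d_indices_stride = nullptr;
  CHECK_CUDA(cudaMalloc(reinterpret_cast<void **>(&d_indices_stride), indices_len));

  // Stage the strides with an asynchronous host-to-device copy on the stream.
  CHECK_CUDA(cudaMemcpyAsync(d_indices_stride, tmp_ind_stride.data(), indices_len,
                             cudaMemcpyHostToDevice, stream));

  // Wait for the copy to finish before the host temporary is reused or goes
  // out of scope; this mirrors the cudaStreamSynchronize added by the fix.
  CHECK_CUDA(cudaStreamSynchronize(stream));

  // ... a kernel such as ScatterNd would be launched on `stream` here ...

  CHECK_CUDA(cudaFree(d_indices_stride));
  CHECK_CUDA(cudaStreamDestroy(stream));
  return 0;
}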