forked from mindspore-Ecosystem/mindspore
scatternd bug fixed.
This commit is contained in:
parent
f2c7cbb3bc
commit
347c9d7f05
|
@ -72,10 +72,9 @@ bool ScatterNdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
S *indices = GetDeviceAddress<S>(inputs, 0);
|
||||
T *update = GetDeviceAddress<T>(inputs, 1);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
S *indifes_stride = GetDeviceAddress<S>(workspace, kIndex0);
|
||||
S *indices_stride = GetDeviceAddress<S>(workspace, kIndex0);
|
||||
S *work_shape = GetDeviceAddress<S>(workspace, kIndex1);
|
||||
|
||||
if (!memcpy_flag_) {
|
||||
const size_t indices_len = sizeof(S) * vec_indices_stride_.size();
|
||||
const size_t vec_work_len = sizeof(S) * attr_shape_.size();
|
||||
std::vector<S> tmp_ind_stride;
|
||||
|
@ -84,33 +83,36 @@ bool ScatterNdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
std::vector<S> tmp_work_shape;
|
||||
(void)std::transform(attr_shape_.begin(), attr_shape_.end(), std::back_inserter(tmp_work_shape),
|
||||
[](int64_t x) { return static_cast<S>(x); });
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(indifes_stride, &tmp_ind_stride[0], indices_len, cudaMemcpyHostToDevice,
|
||||
cudaMemcpyAsync(indices_stride, &tmp_ind_stride[0], indices_len, cudaMemcpyHostToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr_)),
|
||||
"cudaMemcpy for indices_stride failed in "
|
||||
"ScatterNdGpuKernelMod::LaunchKernel.");
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(work_shape, &tmp_work_shape[0], vec_work_len, cudaMemcpyHostToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr_)),
|
||||
"cudaMemcpy for work_shape failed in "
|
||||
"ScatterNdGpuKernelMod::LaunchKernel.");
|
||||
memcpy_flag_ = true;
|
||||
}
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemsetAsync(output, static_cast<T>(0.0), output_size_list_[0], reinterpret_cast<cudaStream_t>(stream_ptr_)),
|
||||
"cudaMemSet failed in ScatterNdGpuKernelMod::LaunchKernel.");
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
|
||||
"cudaStreamSynchronized failed");
|
||||
|
||||
const size_t input_size = input_size_list_[kIndex1] / sizeof(T);
|
||||
const size_t output_size = output_size_list_[kIndex0] / sizeof(T);
|
||||
|
||||
ScatterNd(indices, update, output, block_size_, input_size, output_size, indices_dim_0_, indices_dim_1_,
|
||||
indifes_stride, work_shape, reinterpret_cast<cudaStream_t>(stream_ptr_));
|
||||
indices_stride, work_shape, reinterpret_cast<cudaStream_t>(stream_ptr_));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ScatterNdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -124,7 +126,7 @@ int ScatterNdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
|
|||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
memcpy_flag_ = false;
|
||||
|
||||
if (!TryGetIntValue(inputs, kShapeIndex_, kernel_name_, &attr_shape_)) {
|
||||
MS_LOG(EXCEPTION) << "For " << kernel_name_ << "can't get shape input!";
|
||||
return KRET_RESIZE_FAILED;
|
||||
|
|
|
@ -64,9 +64,6 @@ class ScatterNdGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelpe
|
|||
size_t block_size_{1};
|
||||
size_t indices_dim_0_{0};
|
||||
size_t indices_dim_1_{0};
|
||||
|
||||
// memory in device
|
||||
bool memcpy_flag_{false};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue