forked from mindspore-Ecosystem/mindspore
add indices validation check for tensor_scatter_XXX on gpu backend
parent 018536039f
commit d08ce74f6f
@@ -157,11 +157,81 @@ int TensorScatterArithmeticGpuKernelMod::Resize(const BaseOperatorPtr &base_operator
                      << "', the memory alloc of work_shape_ must be successful, but failed, got size: "
                      << vec_work_len;
  }

  const auto indices_rank = indices_shape.size();
  const auto last_indices_value = LongToSize(indices_shape.back());
  const auto update_rank = update_shape_.size();
  constexpr size_t min_indices_rank = 2;
  slice_size_ = last_indices_value;
  batch_size_ = 1;
  inner_size_ = 1;

  for (size_t i = 0; i < update_rank; ++i) {
    if (i <= indices_rank - min_indices_rank) {
      batch_size_ *= LongToSize(indices_shape[i]);
    } else {
      inner_size_ *= LongToSize(update_shape_[i]);
    }
  }
  batch_strides_.resize(slice_size_);
  for (auto i = SizeToLong(slice_size_) - 1; i >= 0; --i) {
    auto dim = LongToSize(i);
    total_batch_size_ *= input_shape_[dim];
    if (dim == slice_size_ - 1) {
      batch_strides_[dim] = 1;
    } else {
      batch_strides_[dim] = batch_strides_[dim + 1] * input_shape_[dim + 1];
    }
  }

  return KRET_OK;
}
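For orientation, this bookkeeping decomposes the operands the way scatter-nd style kernels usually do: slice_size_ is the length of one index tuple (the last dimension of indices), batch_size_ counts the index tuples, inner_size_ is the number of elements each tuple updates, and batch_strides_ holds row-major strides over the indexed dimensions of the input. A minimal standalone sketch of the same arithmetic (hypothetical helper, not part of the MindSpore API):

// Sketch only: mirrors the Resize() arithmetic above on plain std::vector
// shapes. For input_shape = {4, 5, 6}, indices_shape = {2, 3, 2} and
// update_shape = {2, 3, 6} it yields slice_size = 2, batch_size = 6,
// inner_size = 6, batch_strides = {5, 1}.
#include <cstddef>
#include <vector>

struct ScatterDims {
  size_t slice_size;                  // indices_shape.back(): number of indexed input dims
  size_t batch_size;                  // product of the leading indices dims
  size_t inner_size;                  // product of the trailing update dims
  std::vector<size_t> batch_strides;  // row-major strides over the indexed dims
};

ScatterDims ComputeScatterDims(const std::vector<size_t> &input_shape,
                               const std::vector<size_t> &indices_shape,
                               const std::vector<size_t> &update_shape) {
  ScatterDims d{indices_shape.back(), 1, 1, {}};
  const size_t indices_rank = indices_shape.size();
  for (size_t i = 0; i < update_shape.size(); ++i) {
    if (i <= indices_rank - 2) {
      d.batch_size *= indices_shape[i];  // leading dims enumerate index tuples
    } else {
      d.inner_size *= update_shape[i];   // trailing dims are copied per tuple
    }
  }
  d.batch_strides.resize(d.slice_size);
  for (size_t k = d.slice_size; k-- > 0;) {  // innermost indexed dim has stride 1
    d.batch_strides[k] = (k + 1 == d.slice_size) ? 1 : d.batch_strides[k + 1] * input_shape[k + 1];
  }
  return d;
}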
template <typename T>
using Complex = mindspore::utils::Complex<T>;

template <typename S>
void TensorScatterArithmeticGpuKernelMod::CheckIndicesValid(S *indices) {
  size_t total_indices_num =
    std::accumulate(indices_shape_.begin(), indices_shape_.end(), size_t(1), std::multiplies<size_t>());
  size_t total_indices_bytes = total_indices_num * indices_unit_size_;
  S *host_indices = reinterpret_cast<S *>(malloc(total_indices_bytes));
  MS_EXCEPTION_IF_NULL(host_indices);
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(host_indices, indices, total_indices_bytes, cudaMemcpyDeviceToHost,
                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
    "TensorScatterArithmeticGpuKernelMod cudaMemcpyAsync failed in TensorScatterArithmeticGpuKernelMod::CheckIndicesValid.");
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
                                     "cudaStreamSynchronize failed");
  int64_t invalid_index_pos = -1;
  for (size_t i = 0; i < batch_size_; ++i) {
    size_t out_index = 0;
    for (size_t j = 0; j < slice_size_; ++j) {
      S idx_index = host_indices[SizeToLong(i) * slice_size_ + SizeToLong(j)];
      out_index += batch_strides_[j] * static_cast<size_t>(idx_index);
      if (idx_index < 0 || idx_index >= static_cast<S>(input_shape_[j])) {
        invalid_index_pos = SizeToLong(i * slice_size_);
        break;
      }
    }
    if (invalid_index_pos != -1) {
      break;
    }
  }
  if (invalid_index_pos != -1) {
    std::stringstream indices_ss;
    std::stringstream input_shape_ss;
    for (size_t i = 0; i < slice_size_; i++) {
      if (i > 0) {
        indices_ss << ", ";
        input_shape_ss << ", ";
      }
      indices_ss << std::to_string(host_indices[LongToSize(invalid_index_pos) + i]);
      input_shape_ss << std::to_string(input_shape_[i]);
    }
    free(host_indices);  // release the host copy before throwing
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the " << invalid_index_pos << "-th value of 'indices' ["
                      << indices_ss.str() << "] is out of range of the input shape [" << input_shape_ss.str() << "].";
  }
  free(host_indices);
}
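Since indices lives in device memory, the check first copies it back to the host and synchronizes the stream before scanning. Stripped of the CUDA plumbing, the scan is a per-tuple bounds check that stops at the first violation; a CPU-only sketch under that reading (names are illustrative, not the MindSpore API):

#include <cstddef>
#include <cstdint>
#include <vector>

// Returns the flat offset of the first out-of-range index tuple, or -1 if
// every tuple addresses a valid coordinate of input_shape.
template <typename S>
int64_t FirstInvalidIndex(const std::vector<S> &host_indices, size_t batch_size,
                          size_t slice_size, const std::vector<int64_t> &input_shape) {
  for (size_t i = 0; i < batch_size; ++i) {
    for (size_t j = 0; j < slice_size; ++j) {
      const S idx = host_indices[i * slice_size + j];
      if (idx < 0 || idx >= static_cast<S>(input_shape[j])) {
        return static_cast<int64_t>(i * slice_size);  // start of the offending tuple
      }
    }
  }
  return -1;
}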
template <typename T, typename S>
bool TensorScatterArithmeticGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                                       const std::vector<AddressPtr> &workspace,
@@ -172,6 +242,8 @@ bool TensorScatterArithmeticGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
  T *update = GetDeviceAddress<T>(inputs, kIndex2);
  T *output = GetDeviceAddress<T>(outputs, kIndex0);

  (void)CheckIndicesValid(indices);

  if (!memcpy_flag_) {
    const size_t indices_len = indices_unit_size_ * vec_indices_stride_.size();
    std::vector<S> vec_indices_stride_s = std::vector<S>(vec_indices_stride_.begin(), vec_indices_stride_.end());
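One tradeoff worth noting: because the check runs on every launch, each call now pays a device-to-host copy of the whole indices tensor plus a cudaStreamSynchronize, which serializes the stream. That is the usual price of fail-fast validation; the alternative, letting an out-of-range index write past its slice on the GPU, fails silently or corrupts neighbouring memory.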
@@ -57,6 +57,8 @@ class TensorScatterArithmeticGpuKernelMod : public NativeGpuKernelMod,
  void FreeResource();
  bool GetOpType(const BaseOperatorPtr &base_operator);
  void UpdateSize();
  template <typename S>
  void CheckIndicesValid(S *indices);
  template <typename T, typename S>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);
@@ -81,6 +83,11 @@ class TensorScatterArithmeticGpuKernelMod : public NativeGpuKernelMod,
  void *indices_stride_{nullptr};
  void *work_shape_{nullptr};
  void *stream_ptr_{nullptr};
  size_t slice_size_{1};
  size_t batch_size_{1};
  size_t inner_size_{1};
  size_t total_batch_size_{1};
  std::vector<size_t> batch_strides_;
};
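The five new members (slice_size_ through batch_strides_) cache the shape decomposition computed in Resize so that CheckIndicesValid can reuse it at launch time without re-deriving it from the operand shapes.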
}  // namespace kernel
}  // namespace mindspore
@@ -36,7 +36,7 @@ def test_list_getitem_eliminate():
     Description: Test list_getitem not be replaced as TupleGetItem in pass 'item_tuple_or_list_eliminate'
     Expectation: No exception.
     """
-    max_iter = Tensor([2])
+    max_iter = Tensor([1])
     k = Tensor([0])
     inputs = [Tensor([4]), Tensor([5])]
     net = NetWork()