Add indices validation check for tensor_scatter_XXX on the GPU backend

This commit is contained in:
yiyanzhi_akane 2023-01-30 16:17:15 +08:00
parent 018536039f
commit d08ce74f6f
3 changed files with 80 additions and 1 deletions

View File

@ -157,11 +157,81 @@ int TensorScatterArithmeticGpuKernelMod::Resize(const BaseOperatorPtr &base_oper
<< "', the memory alloc of work_shape_ must be successful, but failed, got size: "
<< vec_work_len;
}
const auto indices_rank = indices_shape.size();
const auto last_indices_value = LongToSize(indices_shape.back());
const auto update_rank = update_shape_.size();
constexpr size_t min_indices_rank = 2;
slice_size_ = last_indices_value;
batch_size_ = 1;
inner_size_ = 1;
for (size_t i = 0; i < update_rank; ++i) {
if (i <= indices_rank - min_indices_rank) {
batch_size_ *= LongToSize(indices_shape[i]);
} else {
inner_size_ *= LongToSize(update_shape_[i]);
}
}
batch_strides_.resize(slice_size_);
for (auto i = SizeToLong(slice_size_) - 1; i >= 0; --i) {
auto dim = LongToSize(i);
total_batch_size_ *= input_shape_[dim];
if (dim == slice_size_ - 1) {
batch_strides_[dim] = 1;
} else {
batch_strides_[dim] = batch_strides_[dim + 1] * input_shape_[dim + 1];
}
}
return KRET_OK;
}
// Shorthand alias for mindspore::utils::Complex<T>.
template <typename T>
using Complex = mindspore::utils::Complex<T>;
// Validates every index tuple of `indices` against the input shape.
//
// The indices are copied from device memory into a host buffer on stream_ptr_, the stream is synchronized,
// and each of the batch_size_ tuples (slice_size_ components each) is checked so that component j lies in
// [0, input_shape_[j]). The first offending tuple is reported through MS_LOG(EXCEPTION).
//
// S:       element type of the indices tensor.
// indices: device pointer to the indices tensor.
template <typename S>
void TensorScatterArithmeticGpuKernelMod::CheckIndicesValid(S *indices) {
  // Seed the product with size_t so the accumulation is not carried out in int.
  const size_t total_indices_num =
    std::accumulate(indices_shape_.begin(), indices_shape_.end(), size_t{1}, std::multiplies<size_t>());
  if (total_indices_num == 0) {
    // Nothing to copy and nothing to validate.
    return;
  }

  // The host buffer is owned by a vector so it is released on every exit path, including the exception below.
  std::vector<S> host_indices(total_indices_num);
  const size_t total_indices_bytes = total_indices_num * sizeof(S);
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(host_indices.data(), indices, total_indices_bytes, cudaMemcpyDeviceToHost,
                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
    "TensorScatterArithmeticGpuKernelMod cudaMemcpy failed in TensorScatterArithmeticGpuKernelMod::CheckIndicesValid.");
  // The copy is asynchronous; wait for it before reading the host buffer.
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
                                     "cudaStreamSynchronize failed");

  constexpr int64_t kNoInvalidIndex = -1;
  // Flat offset (in elements) of the first tuple that contains an out-of-range component.
  int64_t invalid_index_pos = kNoInvalidIndex;
  for (size_t i = 0; i < batch_size_ && invalid_index_pos == kNoInvalidIndex; ++i) {
    const size_t tuple_offset = i * slice_size_;
    for (size_t j = 0; j < slice_size_; ++j) {
      // Compare in 64-bit signed space so neither the index nor the dimension is narrowed to S.
      const auto idx_value = static_cast<int64_t>(host_indices[tuple_offset + j]);
      if (idx_value < 0 || idx_value >= static_cast<int64_t>(input_shape_[j])) {
        invalid_index_pos = SizeToLong(tuple_offset);
        break;
      }
    }
  }
  if (invalid_index_pos == kNoInvalidIndex) {
    return;
  }

  // Render the offending tuple and the dimensions it was checked against for the diagnostic.
  std::stringstream indices_ss;
  std::stringstream input_shape_ss;
  for (size_t i = 0; i < slice_size_; i++) {
    if (i > 0) {
      indices_ss << ", ";
      input_shape_ss << ", ";
    }
    indices_ss << std::to_string(host_indices[LongToSize(invalid_index_pos) + i]);
    input_shape_ss << std::to_string(input_shape_[i]);
  }
  MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the " << invalid_index_pos << "-th value of 'indices'["
                    << indices_ss.str() << "] is out of range[" << input_shape_ss.str() << "].";
}
template <typename T, typename S>
bool TensorScatterArithmeticGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
@ -172,6 +242,8 @@ bool TensorScatterArithmeticGpuKernelMod::LaunchKernel(const std::vector<Address
T *update = GetDeviceAddress<T>(inputs, kIndex2);
T *output = GetDeviceAddress<T>(outputs, kIndex0);
(void)CheckIndicesValid(indices);
if (!memcpy_flag_) {
const size_t indices_len = indices_unit_size_ * vec_indices_stride_.size();
std::vector<S> vec_indices_stride_s = std::vector<S>(vec_indices_stride_.begin(), vec_indices_stride_.end());

View File

@ -57,6 +57,8 @@ class TensorScatterArithmeticGpuKernelMod : public NativeGpuKernelMod,
void FreeResource();
bool GetOpType(const BaseOperatorPtr &base_operator);
void UpdateSize();
// Copies `indices` from device to host and raises an exception when any index component falls outside the
// corresponding input dimension.
template <typename S>
void CheckIndicesValid(S *indices);
template <typename T, typename S>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
@ -81,6 +83,11 @@ class TensorScatterArithmeticGpuKernelMod : public NativeGpuKernelMod,
void *indices_stride_{nullptr};
void *work_shape_{nullptr};
void *stream_ptr_{nullptr};
// Number of components in one index tuple; Resize sets it to the last dimension of the indices shape.
size_t slice_size_{1};
// Number of index tuples; Resize accumulates it from the leading dimensions of the indices shape.
size_t batch_size_{1};
// Product of the update-shape dimensions that Resize does not count towards batch_size_.
size_t inner_size_{1};
// Product of the first slice_size_ input dimensions, multiplied up in Resize.
// NOTE(review): no reset is visible before the multiplication in Resize — confirm it cannot accumulate across calls.
size_t total_batch_size_{1};
// Per-component strides over the first slice_size_ input dimensions; the last component has stride 1.
std::vector<size_t> batch_strides_;
};
} // namespace kernel
} // namespace mindspore

View File

@ -36,7 +36,7 @@ def test_list_getitem_eliminate():
Description: Test list_getitem not be replaced as TupleGetItem in pass 'item_tuple_or_list_eliminate'
Expectation: No exception.
"""
max_iter = Tensor([2])
max_iter = Tensor([1])
k = Tensor([0])
inputs = [Tensor([4]), Tensor([5])]
net = NetWork()