!46949 解决SparseDenseCwiseAdd,SparseDenseCwiseDiv,SparseDenseCwiseMul当索引为负数问题

Merge pull request !46949 from 桂宁馨/SparseDenseCwise_indices_zeros
This commit is contained in:
i-robot 2023-01-05 06:50:21 +00:00 committed by Gitee
commit 44cfcca6f4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 6 additions and 3 deletions

View File

@ -62,7 +62,8 @@ void SparseDenseCwiseAddCpuKernelMod::ComputeAdd(const std::vector<AddressPtr> &
for (int64_t i = 0; i < index_num; i++) {
for (int64_t j = 0; j < dimension; j++) {
if (indices_data[static_cast<size_t>(i * dimension + j)] >= sparse_shape_data[static_cast<size_t>(j)]) {
if (indices_data[static_cast<size_t>(i * dimension + j)] >= sparse_shape_data[static_cast<size_t>(j)] ||
indices_data[i * static_cast<size_t>(dimension) + j] < 0) {
MS_EXCEPTION(ValueError) << "For SparseDenseCwiseAdd, the indices cannot go out of bounds.";
}
}

View File

@ -64,7 +64,8 @@ void SparseDenseCwiseDivCpuKernelMod::ComputeDiv(const std::vector<AddressPtr> &
for (size_t i = 0; i < static_cast<size_t>(index_num); i++) {
for (size_t j = 0; j < static_cast<size_t>(dimension); j++) {
if (indices_data[i * static_cast<size_t>(dimension) + j] >= sparse_shape_data[j]) {
if (indices_data[i * static_cast<size_t>(dimension) + j] >= sparse_shape_data[j] ||
indices_data[i * static_cast<size_t>(dimension) + j] < 0) {
MS_EXCEPTION(ValueError) << "For SparseDenseCwiseDiv, the indices cannot go out of bounds.";
}
}

View File

@ -62,7 +62,8 @@ void SparseDenseCwiseMulCpuKernelMod::ComputeMul(const std::vector<AddressPtr> &
for (int64_t i = 0; i < index_num; i++) {
for (int64_t j = 0; j < dimension; j++) {
if (indices_data[i * dimension + j] >= sparse_shape_data[j]) {
if (indices_data[i * dimension + j] >= sparse_shape_data[j] ||
indices_data[i * static_cast<size_t>(dimension) + j] < 0) {
MS_EXCEPTION(ValueError) << "For SparseDenseCwiseMul, the indices cannot go out of bounds.";
}
}