!49310 bugfix for sparsecountsparseoutput

Merge pull request !49310 from 黄勇/bugfix_sparsecountsparseoutput
This commit is contained in:
i-robot 2023-02-24 11:35:26 +00:00 committed by Gitee
commit f7b362445c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 26 additions and 25 deletions

View File

@ -43,28 +43,34 @@ template <class T>
using BatchedMap = std::vector<std::map<int64_t, T>>;
void SparseCountSparseOutputCpuKernelMod::CheckIndicesInBounds(const int64_t *indices_addr, const int64_t *shape_ptr,
size_t indices_length, bool is_1d, size_t rank) const {
for (size_t i = 0; i < indices_length; i++) {
if ((!is_1d) && (rank == 2)) {
if (i % 2 == 0) {
if (indices_addr[i] >= shape_ptr[0]) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input index value " << indices_addr[i]
<< " must be in [0, " << shape_ptr[0] << ") as given by dense shape";
size_t indices_length, bool is_1d, size_t rank,
int64_t n_batches) const {
if (rank == 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input rank must be greater than 0, but got 0.";
}
indices_length = indices_length / rank;
for (size_t i = 0; i < rank; i++) {
if (!is_1d) {
for (size_t j = 0; j < indices_length; j++) {
if (indices_addr[i + j * rank] >= shape_ptr[i]) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input index value " << indices_addr[i + j * rank]
<< " must be in [0, " << shape_ptr[i] << ") as given by dense shape";
break;
}
} else if (indices_addr[i] >= shape_ptr[1]) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input index value " << indices_addr[i]
<< " must be in [0, " << shape_ptr[1] << ") as given by dense shape";
break;
}
} else if (is_1d) {
if (indices_addr[i] >= shape_ptr[0]) {
} else {
if (indices_addr[i] >= shape_ptr[i]) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input index value " << indices_addr[i]
<< " must be in [0, " << shape_ptr[0] << ") as given by dense shape";
<< " must be in [0, " << shape_ptr[i] << ") as given by dense shape";
break;
}
}
}
if (n_batches <= 0 || n_batches > kMaxBatches) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', cannot allocate " << n_batches
<< " batches, dense shape too wide";
}
}
template <typename T>
@ -167,14 +173,9 @@ bool SparseCountSparseOutputCpuKernelMod::LaunchKernel(const std::vector<kernel:
// Check if values and weights are valid
CheckValidValuesAndWeights<I>(values_addr, use_weights);
// Check that index values are in bounds of the dense shape
CheckIndicesInBounds(indices_addr, shape_ptr, indices_length, is_1d, rank);
int64_t n_batches = is_1d ? 1 : shape_ptr[0];
if (n_batches <= 0 || n_batches > kMaxBatches) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', cannot allocate " << n_batches
<< " batches, dense shape too wide";
}
// Check that index values are in bounds of the dense shape
CheckIndicesInBounds(indices_addr, shape_ptr, indices_length, is_1d, rank, n_batches);
int64_t max_val = 0;
auto per_batch_counts = BatchedMap<T>(shape_ptr[0]);
@ -217,8 +218,8 @@ bool SparseCountSparseOutputCpuKernelMod::LaunchKernel(const std::vector<kernel:
if (is_1d) {
output_indices[value_pos] = x.first;
} else {
output_indices[value_pos * rank] = i;
output_indices[value_pos * rank + 1] = x.first;
output_indices[value_pos * 2] = i;
output_indices[value_pos * 2 + 1] = x.first;
}
output_values[value_pos] = x.second;
++value_pos;
@ -234,7 +235,7 @@ bool SparseCountSparseOutputCpuKernelMod::LaunchKernel(const std::vector<kernel:
}
// Update output shape based on number of dimensions
int64_t num_dim = static_cast<int64_t>(rank);
int64_t num_dim = static_cast<int64_t>(rank) > 1 ? 2 : 1;
std::vector<int64_t> out_indices_shape = {value_pos, num_dim};
std::vector<int64_t> out_values_shape = {value_pos};
std::vector<int64_t> out_dense_shape_shape = {num_dim};

View File

@ -51,7 +51,7 @@ class SparseCountSparseOutputCpuKernelMod : public NativeCpuKernelMod {
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs);
void CheckIndicesInBounds(const int64_t *indices_addr, const int64_t *shape_ptr, size_t indices_length, bool is_1d,
size_t rank) const;
size_t rank, int64_t n_batches) const;
template <typename T>
void CheckValidValuesAndWeights(const T *values_addr, bool use_weights) const;
using SparseCountSparseOutputFunc =