!40806 修复csr_softmax算子nan值报错

Merge pull request !40806 from 王程浩/master
This commit is contained in:
i-robot 2022-08-24 08:43:09 +00:00 committed by Gitee
commit c74662c845
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 3 additions and 3 deletions

View File

@ -85,8 +85,8 @@ void SparseMatrixSoftmaxCpuKernelMod::LaunchKernel(const std::vector<kernel::Add
auto *input_logits_values = reinterpret_cast<T *>(inputs[logits_values]->addr);
auto *input_logits_dense_shape = reinterpret_cast<int *>(inputs[logits_dense_shape]->addr);
auto *input_logits_col_indices = reinterpret_cast<int *>(inputs[logits_col_indices]->addr);
float total = 0;
float MAX = input_logits_values[0];
T total = 0;
T MAX = input_logits_values[0];
int row_index = input_logits_dense_shape[0];
int start = 0;
for (int i = 1; i <= row_index; i++) {

View File

@ -27,7 +27,7 @@ __global__ void SparseMatrixSoftmaxKernel(int rows, IndexType *indptr, DataType
IndexType begin = indptr[id];
IndexType end = indptr[id + 1];
DataType row_max = std::numeric_limits<int>::min();
DataType row_max = values[begin];
for (int r_i = begin; r_i < end; ++r_i) {
row_max = max(row_max, values[r_i]);
}