!19031 sync bugfix of Gather op from enterprize

Merge pull request !19031 from zuochuanyong/sync_enterprize_fixes
This commit is contained in:
i-robot 2021-06-29 02:44:01 +00:00 committed by Gitee
commit 52a91df23b
3 changed files with 8 additions and 4 deletions

View File

@ -70,7 +70,9 @@ void GatherV2CPUKernel<T>::ParallelRun(int8_t *input_addr, int8_t *output_addr,
tasks.emplace_back(block);
thread_index++;
}
common::ThreadPool::GetInstance().SyncRun(tasks);
if (!common::ThreadPool::GetInstance().SyncRun(tasks)) {
MS_LOG(EXCEPTION) << "SyncRun error!";
}
}
template <typename T>

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include "nnacl/base/gather_base.h"
int Gather(const void *input, int outer_size, int inner_size, int limit, const int *indices, int indices_element_size,
@ -26,7 +26,9 @@ int Gather(const void *input, int outer_size, int inner_size, int limit, const i
int8_t *int8_out_m = int8_out + inner_size * m * indices_element_size * data_size;
for (int i = 0; i < indices_element_size; ++i) {
if (indices[i] < 0 || indices[i] > limit) {
if (indices[i] < 0 || indices[i] >= limit) {
printf("[ERROR] [%s:%d] %s] indices[%d]:%d is out of range [%d, %d)\n", __FILE__, __LINE__, __func__, i,
indices[i], 0, limit);
return NNACL_ERR;
}
memcpy(int8_out_m + i * inner_size * data_size, int8_in_m + indices[i] * inner_size * data_size,

View File

@ -797,7 +797,7 @@ class Gather(Primitive):
The original Tensor.
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Specifies the indices of elements of the original Tensor. Must be in the range
`[0, input_param.shape[axis])`. The data type can be int32 or int64.
`[0, input_param.shape[axis])` which are only validated on CPU. The data type can be int32 or int64.
- **axis** (int) - Specifies the dimension index to gather indices.
Outputs: