adjust gather func's in-params' name and synchronize micro

This commit is contained in:
xuanyue 2022-02-25 10:17:42 +08:00
parent c554d4a8b1
commit 1173f55061
3 changed files with 29 additions and 19 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,28 +16,28 @@
#include <stdio.h>
#include "nnacl/base/gather_base.h"
int Gather(const void *input, int64_t outer_size, int64_t inner_size, int64_t limit, const int *indices,
int64_t index_num, void *output, int64_t out_stride) {
int Gather(const void *input, int64_t outer_size, int64_t byte_inner_size, int64_t limit, const int *indices,
int64_t index_num, void *output, int64_t byte_out_stride) {
if (input == NULL || output == NULL || indices == NULL) {
return NNACL_NULL_PTR;
}
const int8_t *int8_in = (int8_t *)input;
int8_t *int8_out = (int8_t *)output;
int64_t in_stride = inner_size * limit;
int64_t in_stride = byte_inner_size * limit;
for (int64_t m = 0; m < outer_size; ++m) {
int8_t *int8_out_m = int8_out;
for (int64_t i = 0; i < index_num; ++i) {
int index = indices[i];
index = index < 0 ? index + limit : index;
if (index < 0 || index >= limit) {
memset(int8_out_m, 0, inner_size);
memset(int8_out_m, 0, byte_inner_size);
} else {
memcpy(int8_out_m, int8_in + index * inner_size, inner_size);
memcpy(int8_out_m, int8_in + index * byte_inner_size, byte_inner_size);
}
int8_out_m += inner_size;
int8_out_m += byte_inner_size;
}
int8_in += in_stride;
int8_out += out_stride;
int8_out += byte_out_stride;
}
return NNACL_OK;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,8 +23,8 @@
#ifdef __cplusplus
extern "C" {
#endif
int Gather(const void *input, int64_t outer_size, int64_t inner_size, int64_t limit, const int *indices,
int64_t indices_element_size, void *output, int64_t data_size);
int Gather(const void *input, int64_t outer_size, int64_t byte_inner_size, int64_t limit, const int *indices,
int64_t index_num, void *output, int64_t byte_out_stride);
#ifdef __cplusplus
}
#endif

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -32,6 +32,8 @@ int GatherFP32Coder::DoCode(CoderContext *context) {
Tensor *input1 = input_tensors_.at(1);
MS_CHECK_PTR(input0);
MS_CHECK_PTR(input1);
MS_CHECK_TRUE_MSG(input1->data_type() == kNumberTypeInt32 || input1->data_type() == kNumberTypeInt, RET_ERROR,
"index's data-type is not int32");
// generate code .h .c
Collect(context,
{
@ -44,7 +46,6 @@ int GatherFP32Coder::DoCode(CoderContext *context) {
NNaclFp32Serializer code;
std::vector<int> in_shape = input0->shape();
int in_rank = static_cast<int>(in_shape.size());
int indices_element_size = input1->ElementsNum();
MS_CHECK_PTR(parameter_);
int axis = (reinterpret_cast<GatherParameter *>(parameter_))->axis_;
MS_CHECK_TRUE(static_cast<int>(in_shape.size()) >= axis, "invalid axis in gather parameter");
@ -57,16 +58,25 @@ int GatherFP32Coder::DoCode(CoderContext *context) {
for (int i = axis + 1; i < in_rank; ++i) {
inner_size *= in_shape.at(i);
}
auto data_size = static_cast<int>(lite::DataTypeSize(input0->data_type()));
int64_t byte_inner_size = inner_size * data_size;
int indices_element_size = input1->ElementsNum();
int64_t byte_out_stride = indices_element_size * byte_inner_size;
MS_CHECK_TRUE(thread_num_ > 0, "thread_num_ <= 0");
int stride = UP_DIV(outer_size, thread_num_);
int start = stride * kDefaultTaskId;
int count = MSMIN(stride, outer_size - stride * kDefaultTaskId);
code << "\t\tconst int8_t *int8_in = (const int8_t *)input0->data();\n";
code << "\t\tMS_CHECK_PTR(int8_in);\n";
code << "\t\tint8_in += " << std::to_string(start * limit * byte_inner_size) << ";\n";
code << "\t\tconst int *index_data = (const int *)input1->data();\n";
code << "\t\tMS_CHECK_PTR(index_data);\n";
code << "\t\tint8_t *int8_out = (int8_t *)output_tensor_->data();\n";
code << "\t\tMS_CHECK_PTR(int8_out);\n";
code << "\t\tint8_out += " << std::to_string(start * byte_out_stride) << ";\n";
// call the op function
if (input0->data_type() == kNumberTypeInt32) {
code.CodeFunction("GatherInt32", input0, count, inner_size, limit, input1, indices_element_size, output_tensor_);
} else {
code.CodeFunction("Gather", input0, count, inner_size, limit, input1, indices_element_size, output_tensor_);
}
code.CodeFunction("Gather", "int8_in", count, byte_inner_size, limit, "index_data", indices_element_size, "int8_out",
byte_out_stride);
context->AppendCode(code.str());
return RET_OK;
}