forked from mindspore-Ecosystem/mindspore
adjust gather func's in-params' name and synchronize micro
This commit is contained in:
parent
c554d4a8b1
commit
1173f55061
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,28 +16,28 @@
|
|||
#include <stdio.h>
|
||||
#include "nnacl/base/gather_base.h"
|
||||
|
||||
int Gather(const void *input, int64_t outer_size, int64_t inner_size, int64_t limit, const int *indices,
|
||||
int64_t index_num, void *output, int64_t out_stride) {
|
||||
int Gather(const void *input, int64_t outer_size, int64_t byte_inner_size, int64_t limit, const int *indices,
|
||||
int64_t index_num, void *output, int64_t byte_out_stride) {
|
||||
if (input == NULL || output == NULL || indices == NULL) {
|
||||
return NNACL_NULL_PTR;
|
||||
}
|
||||
const int8_t *int8_in = (int8_t *)input;
|
||||
int8_t *int8_out = (int8_t *)output;
|
||||
int64_t in_stride = inner_size * limit;
|
||||
int64_t in_stride = byte_inner_size * limit;
|
||||
for (int64_t m = 0; m < outer_size; ++m) {
|
||||
int8_t *int8_out_m = int8_out;
|
||||
for (int64_t i = 0; i < index_num; ++i) {
|
||||
int index = indices[i];
|
||||
index = index < 0 ? index + limit : index;
|
||||
if (index < 0 || index >= limit) {
|
||||
memset(int8_out_m, 0, inner_size);
|
||||
memset(int8_out_m, 0, byte_inner_size);
|
||||
} else {
|
||||
memcpy(int8_out_m, int8_in + index * inner_size, inner_size);
|
||||
memcpy(int8_out_m, int8_in + index * byte_inner_size, byte_inner_size);
|
||||
}
|
||||
int8_out_m += inner_size;
|
||||
int8_out_m += byte_inner_size;
|
||||
}
|
||||
int8_in += in_stride;
|
||||
int8_out += out_stride;
|
||||
int8_out += byte_out_stride;
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -23,8 +23,8 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int Gather(const void *input, int64_t outer_size, int64_t inner_size, int64_t limit, const int *indices,
|
||||
int64_t indices_element_size, void *output, int64_t data_size);
|
||||
int Gather(const void *input, int64_t outer_size, int64_t byte_inner_size, int64_t limit, const int *indices,
|
||||
int64_t index_num, void *output, int64_t byte_out_stride);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -32,6 +32,8 @@ int GatherFP32Coder::DoCode(CoderContext *context) {
|
|||
Tensor *input1 = input_tensors_.at(1);
|
||||
MS_CHECK_PTR(input0);
|
||||
MS_CHECK_PTR(input1);
|
||||
MS_CHECK_TRUE_MSG(input1->data_type() == kNumberTypeInt32 || input1->data_type() == kNumberTypeInt, RET_ERROR,
|
||||
"index's data-type is not int32");
|
||||
// generate code .h .c
|
||||
Collect(context,
|
||||
{
|
||||
|
@ -44,7 +46,6 @@ int GatherFP32Coder::DoCode(CoderContext *context) {
|
|||
NNaclFp32Serializer code;
|
||||
std::vector<int> in_shape = input0->shape();
|
||||
int in_rank = static_cast<int>(in_shape.size());
|
||||
int indices_element_size = input1->ElementsNum();
|
||||
MS_CHECK_PTR(parameter_);
|
||||
int axis = (reinterpret_cast<GatherParameter *>(parameter_))->axis_;
|
||||
MS_CHECK_TRUE(static_cast<int>(in_shape.size()) >= axis, "invalid axis in gather parameter");
|
||||
|
@ -57,16 +58,25 @@ int GatherFP32Coder::DoCode(CoderContext *context) {
|
|||
for (int i = axis + 1; i < in_rank; ++i) {
|
||||
inner_size *= in_shape.at(i);
|
||||
}
|
||||
auto data_size = static_cast<int>(lite::DataTypeSize(input0->data_type()));
|
||||
int64_t byte_inner_size = inner_size * data_size;
|
||||
int indices_element_size = input1->ElementsNum();
|
||||
int64_t byte_out_stride = indices_element_size * byte_inner_size;
|
||||
MS_CHECK_TRUE(thread_num_ > 0, "thread_num_ <= 0");
|
||||
int stride = UP_DIV(outer_size, thread_num_);
|
||||
int start = stride * kDefaultTaskId;
|
||||
int count = MSMIN(stride, outer_size - stride * kDefaultTaskId);
|
||||
|
||||
code << "\t\tconst int8_t *int8_in = (const int8_t *)input0->data();\n";
|
||||
code << "\t\tMS_CHECK_PTR(int8_in);\n";
|
||||
code << "\t\tint8_in += " << std::to_string(start * limit * byte_inner_size) << ";\n";
|
||||
code << "\t\tconst int *index_data = (const int *)input1->data();\n";
|
||||
code << "\t\tMS_CHECK_PTR(index_data);\n";
|
||||
code << "\t\tint8_t *int8_out = (int8_t *)output_tensor_->data();\n";
|
||||
code << "\t\tMS_CHECK_PTR(int8_out);\n";
|
||||
code << "\t\tint8_out += " << std::to_string(start * byte_out_stride) << ";\n";
|
||||
// call the op function
|
||||
if (input0->data_type() == kNumberTypeInt32) {
|
||||
code.CodeFunction("GatherInt32", input0, count, inner_size, limit, input1, indices_element_size, output_tensor_);
|
||||
} else {
|
||||
code.CodeFunction("Gather", input0, count, inner_size, limit, input1, indices_element_size, output_tensor_);
|
||||
}
|
||||
code.CodeFunction("Gather", "int8_in", count, byte_inner_size, limit, "index_data", indices_element_size, "int8_out",
|
||||
byte_out_stride);
|
||||
context->AppendCode(code.str());
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue