!5987 optimize op concat

Merge pull request !5987 from 陶云浩/concat
This commit is contained in:
mindspore-ci-bot 2020-09-14 16:20:42 +08:00 committed by Gitee
commit 38c2f6b3a3
4 changed files with 35 additions and 15 deletions

View File

@ -17,7 +17,8 @@
#include "nnacl/fp32/concat.h"
#include <string.h>
void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output) {
void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output,
int task_id, int thread_num) {
int before_axis_size = 1;
for (int i = 0; i < axis; ++i) {
before_axis_size *= inputs_output_shape[0][i];
@ -33,10 +34,12 @@ void Concat(void **input, int input_num, int axis, int **inputs_output_shape, si
for (int i = 0; i < input_num; ++i) {
uint8_t *src_base = (input[i]);
size_t input_stride = after_axis_size * inputs_output_shape[i][axis];
for (int j = 0; j < before_axis_size; ++j) {
uint8_t *src = src_base + j * input_stride;
uint8_t *dst = dst_base + j * output_stride + axis_offset * after_axis_size;
memcpy(dst, src, input_stride);
int offset = UP_DIV(input_stride, thread_num);
int count = MSMIN(offset, input_stride - offset * task_id);
for (int j = 0; j < before_axis_size; j++) {
uint8_t *src = src_base + j * input_stride + task_id * offset;
uint8_t *dst = dst_base + j * output_stride + axis_offset * after_axis_size + task_id * offset;
memcpy(dst, src, count);
}
axis_offset += inputs_output_shape[i][axis];
}

View File

@ -22,7 +22,8 @@
#ifdef __cplusplus
extern "C" {
#endif
void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output);
void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output,
int task_id, int thread_num);
#ifdef __cplusplus
}
#endif

View File

@ -20,6 +20,8 @@
#include "src/kernel_registry.h"
#include "schema/model_generated.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/thread_pool.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@ -42,12 +44,7 @@ int ConcatCPUKernel::Init() {
int ConcatCPUKernel::ReSize() { return ConcatBaseCPUKernel::ReSize(); }
int ConcatCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
int ConcatCPUKernel::DoConcat(int task_id) {
auto input_num = in_tensors_.size();
std::vector<void *> inputs_addr(input_num, nullptr);
std::vector<int *> inputs_output_shape(input_num + 1, nullptr);
@ -63,7 +60,27 @@ int ConcatCPUKernel::Run() {
auto output_addr = out_tensors_.at(0)->MutableData();
Concat(reinterpret_cast<void **>(inputs_addr.data()), input_num, axis_, inputs_output_shape.data(),
output_shape.size(), output_addr);
output_shape.size(), output_addr, task_id, thread_count_);
return RET_OK;
}
int ConcatsRun(void *cdata, int task_id) {
auto concat_kernel = reinterpret_cast<ConcatCPUKernel *>(cdata);
auto error_code = concat_kernel->DoConcat(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "ConcatsRun error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int ConcatCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, ConcatsRun, this, thread_count_);
return error_code;
}
} // namespace mindspore::kernel

View File

@ -35,9 +35,8 @@ class ConcatCPUKernel : public ConcatBaseCPUKernel {
~ConcatCPUKernel() = default;
int Init() override;
int ReSize() override;
int DoConcat(int task_id);
int Run() override;
};
} // namespace mindspore::kernel