forked from mindspore-Ecosystem/mindspore
commit
38c2f6b3a3
|
@ -17,7 +17,8 @@
|
|||
#include "nnacl/fp32/concat.h"
|
||||
#include <string.h>
|
||||
|
||||
void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output) {
|
||||
void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output,
|
||||
int task_id, int thread_num) {
|
||||
int before_axis_size = 1;
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
before_axis_size *= inputs_output_shape[0][i];
|
||||
|
@ -33,10 +34,12 @@ void Concat(void **input, int input_num, int axis, int **inputs_output_shape, si
|
|||
for (int i = 0; i < input_num; ++i) {
|
||||
uint8_t *src_base = (input[i]);
|
||||
size_t input_stride = after_axis_size * inputs_output_shape[i][axis];
|
||||
for (int j = 0; j < before_axis_size; ++j) {
|
||||
uint8_t *src = src_base + j * input_stride;
|
||||
uint8_t *dst = dst_base + j * output_stride + axis_offset * after_axis_size;
|
||||
memcpy(dst, src, input_stride);
|
||||
int offset = UP_DIV(input_stride, thread_num);
|
||||
int count = MSMIN(offset, input_stride - offset * task_id);
|
||||
for (int j = 0; j < before_axis_size; j++) {
|
||||
uint8_t *src = src_base + j * input_stride + task_id * offset;
|
||||
uint8_t *dst = dst_base + j * output_stride + axis_offset * after_axis_size + task_id * offset;
|
||||
memcpy(dst, src, count);
|
||||
}
|
||||
axis_offset += inputs_output_shape[i][axis];
|
||||
}
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output);
|
||||
void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output,
|
||||
int task_id, int thread_num);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/thread_pool.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -42,12 +44,7 @@ int ConcatCPUKernel::Init() {
|
|||
|
||||
int ConcatCPUKernel::ReSize() { return ConcatBaseCPUKernel::ReSize(); }
|
||||
|
||||
int ConcatCPUKernel::Run() {
|
||||
auto prepare_ret = Prepare();
|
||||
if (prepare_ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
|
||||
return prepare_ret;
|
||||
}
|
||||
int ConcatCPUKernel::DoConcat(int task_id) {
|
||||
auto input_num = in_tensors_.size();
|
||||
std::vector<void *> inputs_addr(input_num, nullptr);
|
||||
std::vector<int *> inputs_output_shape(input_num + 1, nullptr);
|
||||
|
@ -63,7 +60,27 @@ int ConcatCPUKernel::Run() {
|
|||
auto output_addr = out_tensors_.at(0)->MutableData();
|
||||
|
||||
Concat(reinterpret_cast<void **>(inputs_addr.data()), input_num, axis_, inputs_output_shape.data(),
|
||||
output_shape.size(), output_addr);
|
||||
output_shape.size(), output_addr, task_id, thread_count_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConcatsRun(void *cdata, int task_id) {
|
||||
auto concat_kernel = reinterpret_cast<ConcatCPUKernel *>(cdata);
|
||||
auto error_code = concat_kernel->DoConcat(task_id);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConcatsRun error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConcatCPUKernel::Run() {
|
||||
auto prepare_ret = Prepare();
|
||||
if (prepare_ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
|
||||
return prepare_ret;
|
||||
}
|
||||
int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, ConcatsRun, this, thread_count_);
|
||||
return error_code;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -35,9 +35,8 @@ class ConcatCPUKernel : public ConcatBaseCPUKernel {
|
|||
~ConcatCPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
|
||||
int ReSize() override;
|
||||
|
||||
int DoConcat(int task_id);
|
||||
int Run() override;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
|
Loading…
Reference in New Issue