!10450 [MSLITE][DEVELOP] modify npu transpose pass

From: @yangruoqi713
Reviewed-by: @zhang_xue_tong,@zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong,@zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2020-12-26 14:41:49 +08:00 committed by Gitee
commit ef4108b129
8 changed files with 155 additions and 139 deletions

View File

@ -21,20 +21,22 @@
namespace mindspore::lite {
// Reports whether `kernel` is eligible for transpose fusion.
// NOTE: this span is a rendered diff, so the removed and the added form of
// each changed statement appear one after the other.
// std::all_of returns true for an empty range, so a kernel without input
// kernels passes the pre-check and one without output kernels passes the
// post-check.
// NOTE(review): begin() and end() are taken from two separate in_kernels() /
// out_kernels() calls; that is only valid if the accessors return a reference
// to the same container - TODO confirm.
bool CheckFusion(kernel::LiteKernel *kernel) {
// Pre-check: every input kernel is an Nchw2Nhwc transpose with exactly one
// output kernel (the two forms differ only in the lambda parameter name).
auto pre_flag =
std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(), [](const kernel::LiteKernel *kernel) {
return kernel->Type() == schema::PrimitiveType_Nchw2Nhwc && kernel->out_kernels().size() == 1;
std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(), [](const kernel::LiteKernel *in_kernel) {
return in_kernel->Type() == schema::PrimitiveType_Nchw2Nhwc && in_kernel->out_kernels().size() == 1;
});
if (!pre_flag) {
return false;
}
// Post-check, removed form: every output kernel is an Nhwc2Nchw transpose
// with exactly one input kernel.
auto post_flag =
std::all_of(kernel->out_kernels().begin(), kernel->out_kernels().end(), [](const kernel::LiteKernel *kernel) {
return kernel->Type() == schema::PrimitiveType_Nhwc2Nchw && kernel->in_kernels().size() == 1;
});
// Post-check, added form: only the Nhwc2Nchw type is required; the
// single-input-kernel condition is dropped.
auto post_flag = std::all_of(
kernel->out_kernels().begin(), kernel->out_kernels().end(),
[](const kernel::LiteKernel *out_kernel) { return out_kernel->Type() == schema::PrimitiveType_Nhwc2Nchw; });
return post_flag;
}
bool CheckFormatFusion(kernel::LiteKernel *kernel) {
if (kernel->out_kernels().empty()) {
return false;
}
if (kernel->Type() == schema::PrimitiveType_Nhwc2Nchw) {
return std::all_of(
kernel->out_kernels().begin(), kernel->out_kernels().end(),
@ -159,38 +161,26 @@ int TransFormAxis(int axis) {
}
}
// NOTE: rendered diff. The AddFusion signature and its CheckFusion guard are
// the removed lines; the UpdateKernel and CommonFusion definitions are the
// added ones. In the added form the CheckFusion guard is applied by Run()
// before it dispatches to a fusion routine.
int NPUFusionPass::AddFusion(kernel::LiteKernel *kernel) {
if (!CheckFusion(kernel)) {
return RET_OK;
}
// Applies the four update steps to `kernel` in a fixed order:
// pre tensors, post tensors, pre kernels, post kernels.
void NPUFusionPass::UpdateKernel(kernel::LiteKernel *kernel) {
UpdatePreTensors(kernel);
UpdatePostTensors(kernel);
UpdatePreKernels(kernel);
UpdatePostKernels(kernel);
}
// Fusion for kernels that need no parameter adjustment: delegates to
// UpdateKernel and always returns RET_OK.
int NPUFusionPass::CommonFusion(kernel::LiteKernel *kernel) {
UpdateKernel(kernel);
return RET_OK;
}
// Fuses the transposes around a Concat kernel and converts its axis.
// NOTE: rendered diff. The CheckFusion guard and the four individual Update*
// calls are the removed lines; the single UpdateKernel(kernel) call is the
// added replacement.
// Always returns RET_OK.
int NPUFusionPass::ConcatFusion(kernel::LiteKernel *kernel) {
if (!CheckFusion(kernel)) {
return RET_OK;
}
UpdatePreTensors(kernel);
UpdatePostTensors(kernel);
UpdatePreKernels(kernel);
UpdatePostKernels(kernel);
UpdateKernel(kernel);
// The op parameter is reinterpreted as ConcatParameter without a type check;
// Run() calls this function only from its PrimitiveType_Concat case.
auto concat_param = reinterpret_cast<ConcatParameter *>(kernel->op_parameter());
// Map the concat axis through TransFormAxis after the layout change.
concat_param->axis_ = TransFormAxis(concat_param->axis_);
return RET_OK;
}
int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
if (kernel->out_kernels().empty()) {
return RET_OK;
}
if (!CheckFormatFusion(kernel)) {
return RET_OK;
}
auto pre_kernel = kernel->in_kernels()[0];
auto in_tensor = kernel->in_tensors()[0];
auto out_tensor = kernel->out_tensors()[0];
@ -237,17 +227,28 @@ int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
}
// Walks the kernel list and applies the matching fusion routine to each
// kernel.
// NOTE: rendered diff. Removed and added lines are interleaved, and the
// function continues past the last line visible in this hunk.
int NPUFusionPass::Run() {
// Removed form: range-based loop over the kernel list.
for (auto kernel : *kernels) {
// Added form: index-based loop so the position can be moved back after a
// fusion.
// NOTE(review): presumably the fusion routines remove kernels from `kernels`,
// which is why `i` is rewound - confirm.
// NOTE(review): `i` is size_t. `i--` at 0 wraps and the loop's `i++` brings it
// back to 0, but `i -= in_kernels().size()` with a size larger than i + 1
// leaves a wrapped value after `i++` and ends the loop early - verify that
// this cannot happen.
for (size_t i = 0; i < kernels->size(); i++) {
auto kernel = (*kernels)[i];
// NOTE(review): both operands of `||` test PrimitiveType_Nchw2Nhwc, so the
// second one is redundant. CheckFormatFusion has a branch for Nhwc2Nchw, so
// the second operand was presumably meant to be PrimitiveType_Nhwc2Nchw -
// confirm.
if (kernel->Type() == schema::PrimitiveType_Nchw2Nhwc || kernel->Type() == schema::PrimitiveType_Nchw2Nhwc) {
if (CheckFormatFusion(kernel)) {
i--;
FormatFusion(kernel);
}
continue;
}
// Every remaining kernel must pass CheckFusion before any fusion is applied.
if (!CheckFusion(kernel)) {
continue;
}
switch (kernel->Type()) {
case schema::PrimitiveType_Concat:
i -= kernel->in_kernels().size();
ConcatFusion(kernel);
continue;
case schema::PrimitiveType_Add:
case schema::PrimitiveType_Activation:
// Removed lines: the AddFusion call and the dedicated Nchw2Nhwc case below.
AddFusion(kernel);
continue;
case schema::PrimitiveType_Nchw2Nhwc:
FormatFusion(kernel);
// Added lines: Add, Activation and Eltwise all go through CommonFusion.
case schema::PrimitiveType_Eltwise:
i -= kernel->in_kernels().size();
CommonFusion(kernel);
continue;
default:
continue;

View File

@ -33,11 +33,12 @@ class NPUFusionPass : public NPUBasePass {
int Run() override;
protected:
void RemoveAndFreeKernel(kernel::LiteKernel *cur_kernel);
void UpdatePreKernels(kernel::LiteKernel *kernel);
void UpdatePostKernels(kernel::LiteKernel *kernel);
void RemoveAndFreeKernel(kernel::LiteKernel *cur_kernel);
void UpdateKernel(kernel::LiteKernel *kernel);
int CommonFusion(kernel::LiteKernel *kernel);
int ConcatFusion(kernel::LiteKernel *kernel);
int AddFusion(kernel::LiteKernel *kernel);
int FormatFusion(kernel::LiteKernel *kernel);
private:

View File

@ -21,7 +21,9 @@ namespace mindspore::lite {
using kernel::KERNEL_ARCH::kNPU;
// Insertion decision values; GetInsertState below returns InsertNone on at
// least one path.
enum InsertState { InsertNone, PreInsert, PostInsert };
// Primitive types looked up by GetInsertState.
// NOTE: rendered diff. The one-line definition is the removed form (Concat,
// Add); the three-line definition is the added form, which also lists Eltwise
// and Activation.
std::set<mindspore::schema::PrimitiveType> npu_insert_nodes = {schema::PrimitiveType_Concat, schema::PrimitiveType_Add};
std::set<mindspore::schema::PrimitiveType> npu_insert_nodes = {schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
schema::PrimitiveType_Eltwise,
schema::PrimitiveType_Activation};
int GetInsertState(kernel::LiteKernel *kernel) {
if (npu_insert_nodes.find(kernel->Type()) == npu_insert_nodes.end()) {
@ -42,15 +44,54 @@ int GetInsertState(kernel::LiteKernel *kernel) {
return InsertNone;
}
// For every input kernel of `kernel` that is not already an Nchw2Nhwc
// transpose, inserts an Nhwc2Nchw kernel followed by an Nchw2Nhwc kernel
// between that input kernel and `kernel`.
// The two new tensors are appended to `all_tensors`, the two new kernels to
// `trans_kernels` (added form; `all_kernels` in the removed form), and their
// primitives to insert_primitive_. Always returns RET_OK.
// NOTE: rendered diff. Removed and added forms of changed lines are
// interleaved; the added form renames cur_kernel -> kernel and
// kernel -> in_kernel.
// NOTE(review): the loop iterates kernel->in_kernels() while
// UpdateNC2NHTransNodeAfterKernel at the end of the body replaces that same
// kernel's in_kernels via set_in_kernels. If in_kernels() returns a reference,
// the iteration may continue over an invalidated range - confirm.
int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::LiteKernel *cur_kernel,
std::vector<kernel::LiteKernel *> *all_kernels,
int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel,
std::vector<kernel::LiteKernel *> *trans_kernels,
std::vector<Tensor *> *all_tensors) {
for (auto kernel : cur_kernel->in_kernels()) {
if (kernel->Type() == schema::PrimitiveType_Nchw2Nhwc) {
for (auto in_kernel : kernel->in_kernels()) {
if (in_kernel->Type() == schema::PrimitiveType_Nchw2Nhwc) {
continue;
}
// Removed form read the shape from cur_kernel's first output tensor; the
// added form reads it from the input kernel's first output tensor.
// Indexing [0]..[3] assumes a 4-D shape - TODO confirm.
auto nhwc_shape = cur_kernel->out_tensors()[0]->shape();
auto nhwc_shape = in_kernel->out_tensors()[0]->shape();
std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
// NOTE(review): the tensor built with nchw_shape is tagged Format_NHWC and
// the one built with nhwc_shape is tagged Format_NCHW; the pairing looks
// swapped - confirm whether this is intentional.
// Both tensors are allocated with raw `new`; the pointers are stored in
// all_tensors (presumably the owner - verify).
auto nh2nc_tensor =
new Tensor(in_kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
std::vector<Tensor *> nh2nc_tensors = {nh2nc_tensor};
all_tensors->push_back(nh2nc_tensors[0]);
auto nc2nh_tensor = new Tensor(nh2nc_tensor->data_type(), nhwc_shape, schema::Format_NCHW, Tensor::VAR);
std::vector<Tensor *> nc2nh_tensors = {nc2nh_tensor};
all_tensors->push_back(nc2nh_tensors[0]);
// `total` is incremented on each use so that generated kernel names get
// distinct numeric suffixes.
auto nh2nc_name = in_kernel->name() + "_nh2nc_" + std::to_string(total++);
auto *nh2nc_kernel =
NPUPassUtils::CreateNhwc2NchwKernel(in_kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name);
trans_kernels->push_back(nh2nc_kernel);
insert_primitive_.push_back(nh2nc_kernel->GetPrimitive());
auto nc2nh_name = in_kernel->name() + "_nc2nh_" + std::to_string(total++);
auto *nc2nh_kernel = NPUPassUtils::CreateNchw2NhwcKernel(nh2nc_tensors, nc2nh_tensors, context, nc2nh_name);
trans_kernels->push_back(nc2nh_kernel);
insert_primitive_.push_back(nc2nh_kernel->GetPrimitive());
// Wire the chain in_kernel -> nh2nc_kernel -> nc2nh_kernel -> kernel, then
// patch the neighbours so they reference the new transpose kernels.
NPUPassUtils::UpdateKernel(nh2nc_kernel, {in_kernel}, {nc2nh_kernel}, in_kernel->out_tensors(), nh2nc_tensors);
NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {kernel}, nh2nc_tensors, nc2nh_tensors);
NPUPassUtils::UpdateNH2NCTransNodePreKernel(in_kernel, nh2nc_kernel, kernel);
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(in_kernel, nc2nh_kernel, kernel);
}
return RET_OK;
}
int NPUInsertTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel,
std::vector<kernel::LiteKernel *> *trans_kernels,
std::vector<Tensor *> *all_tensors) {
for (auto out_kernel : kernel->out_kernels()) {
if (out_kernel->Type() == schema::PrimitiveType_Nhwc2Nchw) {
continue;
}
auto nhwc_shape = kernel->out_tensors()[0]->shape();
std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
auto nh2nc_tensor = new Tensor(kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
std::vector<Tensor *> nh2nc_tensors = {nh2nc_tensor};
all_tensors->push_back(nh2nc_tensors[0]);
@ -61,52 +102,18 @@ int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::L
auto nh2nc_name = kernel->name() + "_nh2nc_" + std::to_string(total++);
auto *nh2nc_kernel = NPUPassUtils::CreateNhwc2NchwKernel(kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name);
all_kernels->push_back(nh2nc_kernel);
trans_kernels->push_back(nh2nc_kernel);
insert_primitive_.push_back(nh2nc_kernel->GetPrimitive());
auto nc2nh_name = kernel->name() + "_nc2nh_" + std::to_string(total++);
auto *nc2nh_kernel = NPUPassUtils::CreateNchw2NhwcKernel(nh2nc_tensors, nc2nh_tensors, context, nc2nh_name);
all_kernels->push_back(nc2nh_kernel);
trans_kernels->push_back(nc2nh_kernel);
insert_primitive_.push_back(nc2nh_kernel->GetPrimitive());
NPUPassUtils::UpdateKernel(nh2nc_kernel, {kernel}, {nc2nh_kernel}, kernel->out_tensors(), nh2nc_tensors);
NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {cur_kernel}, nh2nc_tensors, nc2nh_tensors);
NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel, nh2nc_kernel, cur_kernel);
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, nc2nh_kernel, cur_kernel);
}
return RET_OK;
}
int NPUInsertTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKernel *cur_kernel,
std::vector<kernel::LiteKernel *> *all_kernels,
std::vector<Tensor *> *all_tensors) {
for (auto out_kernel : cur_kernel->out_kernels()) {
if (out_kernel->Type() == schema::PrimitiveType_Nhwc2Nchw) {
continue;
}
auto nhwc_shape = cur_kernel->out_tensors()[0]->shape();
std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
auto nh2nc_tensor =
new Tensor(cur_kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
std::vector<Tensor *> nh2nc_tensors = {nh2nc_tensor};
all_tensors->push_back(nh2nc_tensors[0]);
auto nc2nh_tensor = new Tensor(nh2nc_tensor->data_type(), nhwc_shape, schema::Format_NCHW, Tensor::VAR);
std::vector<Tensor *> nc2nh_tensors = {nc2nh_tensor};
all_tensors->push_back(nc2nh_tensors[0]);
auto nh2nc_name = cur_kernel->name() + "_nh2nc_" + std::to_string(total++);
auto *nh2nc_kernel =
NPUPassUtils::CreateNhwc2NchwKernel(cur_kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name);
all_kernels->push_back(nh2nc_kernel);
insert_primitive_.push_back(nh2nc_kernel->GetPrimitive());
auto nc2nh_name = cur_kernel->name() + "_nc2nh_" + std::to_string(total++);
auto *nc2nh_kernel = NPUPassUtils::CreateNchw2NhwcKernel(nh2nc_tensors, nc2nh_tensors, context, nc2nh_name);
all_kernels->push_back(nc2nh_kernel);
insert_primitive_.push_back(nc2nh_kernel->GetPrimitive());
NPUPassUtils::UpdateKernel(nh2nc_kernel, {cur_kernel}, {nc2nh_kernel}, cur_kernel->out_tensors(), nh2nc_tensors);
NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {out_kernel}, nh2nc_tensors, nc2nh_tensors);
NPUPassUtils::UpdateNH2NCTransNodePreKernel(cur_kernel, nh2nc_kernel, out_kernel);
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(cur_kernel, nc2nh_kernel, out_kernel);
NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel, nh2nc_kernel, out_kernel);
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, nc2nh_kernel, out_kernel);
}
return RET_OK;
}

View File

@ -41,11 +41,11 @@ class NPUInsertTransformPass : public NPUBasePass {
int Run() override;
private:
int InsertPreNode(const InnerContext *context, kernel::LiteKernel *cur_kernel,
std::vector<kernel::LiteKernel *> *all_kernels, std::vector<Tensor *> *all_tensors);
int InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel,
std::vector<kernel::LiteKernel *> *trans_kernels, std::vector<Tensor *> *all_tensors);
int InsertPostNode(const InnerContext *context, kernel::LiteKernel *cur_kernel,
std::vector<kernel::LiteKernel *> *all_kernels, std::vector<Tensor *> *all_tensors);
int InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel,
std::vector<kernel::LiteKernel *> *trans_kernels, std::vector<Tensor *> *all_tensors);
private:
int total = 0;

View File

@ -100,25 +100,25 @@ void NPUPassUtils::UpdateKernel(kernel::LiteKernel *kernel, const std::vector<ke
kernel->set_out_kernels(out_kernels);
}
// Replaces `kernel` with `trans_kernel` in the predecessor's out_kernels list,
// keeping every other entry and the original order.
// NOTE: rendered diff. The removed form named the parameters
// (kernel, trans_kernel, after_kernel) and finished with a full UpdateKernel
// call that re-set in_kernels, in_tensors and out_tensors to their current
// values; the added form names them (pre_kernel, trans_kernel, kernel) and
// only calls set_out_kernels.
void NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
kernel::LiteKernel *after_kernel) {
void NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *pre_kernel, kernel::LiteKernel *trans_kernel,
kernel::LiteKernel *kernel) {
std::vector<kernel::LiteKernel *> out_kernels;
for (auto out_kernel : kernel->out_kernels()) {
if (out_kernel == after_kernel) {
for (auto out_kernel : pre_kernel->out_kernels()) {
if (out_kernel == kernel) {
out_kernels.push_back(trans_kernel);
} else {
out_kernels.push_back(out_kernel);
}
}
UpdateKernel(kernel, kernel->in_kernels(), out_kernels, kernel->in_tensors(), kernel->out_tensors());
pre_kernel->set_out_kernels(out_kernels);
}
void NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
kernel::LiteKernel *next_kernel) {
kernel::LiteKernel *post_kernel) {
std::vector<kernel::LiteKernel *> cur_out_kernels;
for (auto out_kernel : kernel->out_kernels()) {
if (out_kernel == next_kernel) {
if (out_kernel == post_kernel) {
cur_out_kernels.push_back(trans_kernel);
} else {
cur_out_kernels.push_back(out_kernel);
@ -130,45 +130,47 @@ void NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel::LiteKernel *kernel, ker
std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
kernel_out_tensor->set_format(schema::Format_NCHW);
kernel_out_tensor->set_shape(nchw_shape);
UpdateKernel(kernel, kernel->in_kernels(), cur_out_kernels, kernel->in_tensors(), {kernel_out_tensor});
kernel->set_out_kernels(cur_out_kernels);
kernel->set_out_tensors({kernel_out_tensor});
}
// Rewires `kernel` so that `trans_kernel` becomes its first input:
// - in_tensors: trans_kernel's first output tensor takes slot 0; the tensors
//   at index 1 and above are kept.
// - in_kernels: trans_kernel goes first, followed by the kept input kernels.
// NOTE: rendered diff. The removed form scanned in_kernels from index 0 and
// ended with a full UpdateKernel call; the added form scans from index 1 and
// calls set_in_kernels / set_in_tensors directly.
// NOTE(review): the third parameter (before_kernel / pre_kernel) is not
// referenced anywhere in the visible body.
// NOTE(review): `in_kernel != kernel` compares each input kernel with `kernel`
// itself, which only filters a self-reference. The third parameter was
// presumably intended - confirm.
// NOTE(review): starting the in_kernels scan at index 1 assumes the replaced
// predecessor is in_kernels()[0] - confirm.
// NOTE(review): the loop counters are `int` compared against size(), which is
// presumably unsigned (signed/unsigned comparison).
void NPUPassUtils::UpdateNH2NCTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
kernel::LiteKernel *before_kernel) {
kernel::LiteKernel *pre_kernel) {
std::vector<lite::Tensor *> cur_kernel_in_tensors = {trans_kernel->out_tensors()[0]};
for (int i = 1; i < kernel->in_tensors().size(); i++) {
cur_kernel_in_tensors.push_back(kernel->in_tensors()[i]);
}
std::vector<kernel::LiteKernel *> cur_in_kernels = {trans_kernel};
for (int i = 0; i < kernel->in_kernels().size(); i++) {
for (int i = 1; i < kernel->in_kernels().size(); i++) {
auto in_kernel = kernel->in_kernels()[i];
if (in_kernel != kernel) {
cur_in_kernels.push_back(in_kernel);
}
}
UpdateKernel(kernel, cur_in_kernels, kernel->out_kernels(), cur_kernel_in_tensors, kernel->out_tensors());
kernel->set_in_kernels(cur_in_kernels);
kernel->set_in_tensors({cur_kernel_in_tensors});
}
// Redirects the successor kernel from `kernel` to `trans_kernel`:
// - every in_tensor equal to kernel's first output tensor is replaced by
//   trans_kernel's first output tensor;
// - every in_kernel equal to `kernel` is replaced by `trans_kernel`.
// Other entries and the original order are preserved.
// NOTE: rendered diff. The removed form used next_* names and ended with a
// full UpdateKernel call; the added form uses post_* names and calls the
// setters directly.
// NOTE(review): in the added form set_in_tensors is called twice with the same
// contents (once after the tensor loop, once at the end); the second call
// looks redundant - confirm.
void NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
kernel::LiteKernel *next_kernel) {
std::vector<Tensor *> next_in_tensors;
for (auto next_in_tensor : next_kernel->in_tensors()) {
if (next_in_tensor != kernel->out_tensors()[0]) {
next_in_tensors.push_back(next_in_tensor);
kernel::LiteKernel *post_kernel) {
std::vector<Tensor *> post_in_tensors;
for (auto post_in_tensor : post_kernel->in_tensors()) {
if (post_in_tensor != kernel->out_tensors()[0]) {
post_in_tensors.push_back(post_in_tensor);
} else {
next_in_tensors.push_back(trans_kernel->out_tensors()[0]);
post_in_tensors.push_back(trans_kernel->out_tensors()[0]);
}
}
next_kernel->set_in_tensors(next_in_tensors);
std::vector<kernel::LiteKernel *> next_in_kernels;
for (auto in_kernel : next_kernel->in_kernels()) {
post_kernel->set_in_tensors(post_in_tensors);
std::vector<kernel::LiteKernel *> post_in_kernels;
for (auto in_kernel : post_kernel->in_kernels()) {
if (in_kernel == kernel) {
next_in_kernels.push_back(trans_kernel);
post_in_kernels.push_back(trans_kernel);
} else {
next_in_kernels.push_back(in_kernel);
post_in_kernels.push_back(in_kernel);
}
}
NPUPassUtils::UpdateKernel(next_kernel, next_in_kernels, next_kernel->out_kernels(), next_in_tensors,
next_kernel->out_tensors());
post_kernel->set_in_kernels(post_in_kernels);
post_kernel->set_in_tensors({post_in_tensors});
}
} // namespace mindspore::lite

View File

@ -35,17 +35,17 @@ class NPUPassUtils {
const std::vector<kernel::LiteKernel *> &out_kernels,
const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors);
static void UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
kernel::LiteKernel *after_kernel);
static void UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *pre_kernel, kernel::LiteKernel *trans_kernel,
kernel::LiteKernel *kernel);
static void UpdateNC2NHTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
kernel::LiteKernel *next_kernel);
kernel::LiteKernel *post_kernel);
static void UpdateNH2NCTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
kernel::LiteKernel *before_kernel);
kernel::LiteKernel *pre_kernel);
static void UpdateNC2NHTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
kernel::LiteKernel *next_kernel);
kernel::LiteKernel *post_kernel);
private:
static PrimitiveC *CreateNchw2NhwcPrimitive();

View File

@ -19,51 +19,53 @@
#include "src/runtime/agent/npu/npu_manager.h"
#include "src/runtime/agent/npu/optimizer/npu_pass_utils.h"
namespace mindspore::lite {
using kernel::KERNEL_ARCH::kCPU;
using kernel::KERNEL_ARCH::kNPU;
// Inserts an Nhwc2Nchw transform kernel in front of `kernel` when `kernel` has
// no input kernels, or its first input kernel is not an NPU kernel, or that
// kernel's type is not in npu_trans_nodes.
// The new output tensor is appended to `all_tensors`, the new kernel to
// `trans_kernels` (added form; `all_kernels` in the removed form) and its
// primitive to insert_primitive_. Always returns RET_OK.
// NOTE: rendered diff. Removed and added forms of changed lines are
// interleaved; the added form renames before_kernel -> pre_kernel and
// pre_trans_kernel -> trans_kernel.
int NPUTransformPass::InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel,
std::vector<kernel::LiteKernel *> *all_kernels,
std::vector<kernel::LiteKernel *> *trans_kernels,
std::vector<Tensor *> *all_tensors) {
bool is_input_kernel = kernel->in_kernels().empty();
// Short-circuit evaluation keeps in_kernels()[0] from being read when the
// list is empty. Only the first input kernel is inspected.
if (is_input_kernel || kernel->in_kernels()[0]->desc().arch != kNPU ||
npu_trans_nodes.find(kernel->in_kernels()[0]->Type()) == npu_trans_nodes.end()) {
kernel::LiteKernel *before_kernel = nullptr;
kernel::LiteKernel *pre_kernel = nullptr;
if (!is_input_kernel) {
before_kernel = kernel->in_kernels()[0];
pre_kernel = kernel->in_kernels()[0];
}
// Create pre transform kernel out tensors.
// Create pre transform kernel's out tensor.
// Indexing [0]..[3] assumes the first input tensor has a 4-D shape - TODO
// confirm. The tensor is allocated with raw `new` and its pointer is stored in
// all_tensors (presumably the owner - verify).
auto nhwc_shape = kernel->in_tensors()[0]->shape();
std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
auto tensor = new Tensor(kernel->in_tensors()[0]->data_type(), nchw_shape, schema::Format_NCHW, Tensor::VAR);
std::vector<Tensor *> pre_trans_out_tensors = {tensor};
all_tensors->push_back(pre_trans_out_tensors[0]);
// Replace the output tensor of the previous node
// Create pre transform kernel: Nhwc2Nchw
// `total` is incremented on each use so that generated kernel names get
// distinct numeric suffixes.
auto name = kernel->name() + "_pre_trans" + "_Nhwc2Nchw_" + std::to_string(total++);
auto *pre_trans_kernel =
auto *trans_kernel =
NPUPassUtils::CreateNhwc2NchwKernel({kernel->in_tensors()[0]}, pre_trans_out_tensors, context, name);
// Insert Nhwc2Nchw into the front of the current queue
all_kernels->push_back(pre_trans_kernel);
insert_primitive_.push_back(pre_trans_kernel->GetPrimitive());
// Replace the output kernel of the previous node
trans_kernels->push_back(trans_kernel);
insert_primitive_.push_back(trans_kernel->GetPrimitive());
// Set in_kernels, out_kernels, in_tensors,out_tensors for transform kernel
std::vector<kernel::LiteKernel *> pre_trans_in_kernel;
if (is_input_kernel) {
pre_trans_in_kernel = {};
} else {
pre_trans_in_kernel = {before_kernel};
pre_trans_in_kernel = {pre_kernel};
}
NPUPassUtils::UpdateKernel(pre_trans_kernel, pre_trans_in_kernel, {kernel}, {kernel->in_tensors()[0]},
NPUPassUtils::UpdateKernel(trans_kernel, pre_trans_in_kernel, {kernel}, {kernel->in_tensors()[0]},
pre_trans_out_tensors);
// The predecessor is patched only when one exists; `kernel` itself is always
// rewired to take the transform kernel as its first input.
if (before_kernel != nullptr) {
NPUPassUtils::UpdateNH2NCTransNodePreKernel(before_kernel, pre_trans_kernel, kernel);
if (pre_kernel != nullptr) {
NPUPassUtils::UpdateNH2NCTransNodePreKernel(pre_kernel, trans_kernel, kernel);
}
NPUPassUtils::UpdateNH2NCTransNodeAfterKernel(kernel, pre_trans_kernel, before_kernel);
NPUPassUtils::UpdateNH2NCTransNodeAfterKernel(kernel, trans_kernel, pre_kernel);
}
return RET_OK;
}
int NPUTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel,
std::vector<kernel::LiteKernel *> *all_kernels,
std::vector<kernel::LiteKernel *> *trans_kernels,
std::vector<Tensor *> *all_tensors) {
// Model output does not insert operator
if (kernel->out_kernels().empty()) {
@ -71,27 +73,30 @@ int NPUTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKe
}
// Single output multiple references
for (int i = 0; i < kernel->out_kernels().size(); i++) {
auto next_kernel = kernel->out_kernels().at(i);
if (next_kernel->desc().arch == kNPU && npu_trans_nodes.find(next_kernel->Type()) != npu_trans_nodes.end()) {
auto post_kernel = kernel->out_kernels().at(i);
if (post_kernel->desc().arch == kNPU && npu_trans_nodes.find(post_kernel->Type()) != npu_trans_nodes.end()) {
continue;
}
// Change format the output of the current kernel nhwc->nchw
// Create post transform kernel's out tensor.
auto tensor = new Tensor(kernel->out_tensors()[0]->data_type(), kernel->out_tensors()[0]->shape(),
schema::Format_NHWC, Tensor::VAR);
std::vector<Tensor *> post_trans_out_tensors = {tensor};
all_tensors->push_back(post_trans_out_tensors[0]);
// Use the output tensor of the current node as the input tensor of the post-conversion operator
// Create post transform kernel: Nchw2Nhwc
auto name = kernel->name() + "_post_trans" + "_Nchw2Nhwc" + std::to_string(total++);
auto *post_trans_kernel =
NPUPassUtils::CreateNchw2NhwcKernel(kernel->out_tensors(), post_trans_out_tensors, context, name);
// Replace the input tensor of the next node
NPUPassUtils::UpdateKernel(post_trans_kernel, {kernel}, {next_kernel}, kernel->out_tensors(),
// Set in_kernels, out_kernels, in_tensors,out_tensors for transform kernel
NPUPassUtils::UpdateKernel(post_trans_kernel, {kernel}, {post_kernel}, kernel->out_tensors(),
post_trans_out_tensors);
insert_primitive_.push_back(post_trans_kernel->GetPrimitive());
// Directly insert in the back, will not affect the topological sort
all_kernels->push_back(post_trans_kernel);
NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel, post_trans_kernel, next_kernel);
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, post_trans_kernel, next_kernel);
trans_kernels->push_back(post_trans_kernel);
NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel, post_trans_kernel, post_kernel);
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, post_trans_kernel, post_kernel);
}
return RET_OK;
}

View File

@ -43,10 +43,10 @@ class NPUTransformPass : public NPUBasePass {
private:
int InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel,
std::vector<kernel::LiteKernel *> *all_kernels, std::vector<Tensor *> *all_tensors);
std::vector<kernel::LiteKernel *> *trans_kernels, std::vector<Tensor *> *all_tensors);
int InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel,
std::vector<kernel::LiteKernel *> *all_kernels, std::vector<Tensor *> *all_tensors);
std::vector<kernel::LiteKernel *> *trans_kernels, std::vector<Tensor *> *all_tensors);
private:
int total = 0;