forked from mindspore-Ecosystem/mindspore
!10450 [MSLITE][DEVELOP] modify npu transpose pass
From: @yangruoqi713 Reviewed-by: @zhang_xue_tong,@zhang_xue_tong,@hangangqiang Signed-off-by: @zhang_xue_tong,@zhang_xue_tong
This commit is contained in:
commit
ef4108b129
|
@ -21,20 +21,22 @@
|
|||
namespace mindspore::lite {
|
||||
bool CheckFusion(kernel::LiteKernel *kernel) {
|
||||
auto pre_flag =
|
||||
std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(), [](const kernel::LiteKernel *kernel) {
|
||||
return kernel->Type() == schema::PrimitiveType_Nchw2Nhwc && kernel->out_kernels().size() == 1;
|
||||
std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(), [](const kernel::LiteKernel *in_kernel) {
|
||||
return in_kernel->Type() == schema::PrimitiveType_Nchw2Nhwc && in_kernel->out_kernels().size() == 1;
|
||||
});
|
||||
if (!pre_flag) {
|
||||
return false;
|
||||
}
|
||||
auto post_flag =
|
||||
std::all_of(kernel->out_kernels().begin(), kernel->out_kernels().end(), [](const kernel::LiteKernel *kernel) {
|
||||
return kernel->Type() == schema::PrimitiveType_Nhwc2Nchw && kernel->in_kernels().size() == 1;
|
||||
});
|
||||
auto post_flag = std::all_of(
|
||||
kernel->out_kernels().begin(), kernel->out_kernels().end(),
|
||||
[](const kernel::LiteKernel *out_kernel) { return out_kernel->Type() == schema::PrimitiveType_Nhwc2Nchw; });
|
||||
return post_flag;
|
||||
}
|
||||
|
||||
bool CheckFormatFusion(kernel::LiteKernel *kernel) {
|
||||
if (kernel->out_kernels().empty()) {
|
||||
return false;
|
||||
}
|
||||
if (kernel->Type() == schema::PrimitiveType_Nhwc2Nchw) {
|
||||
return std::all_of(
|
||||
kernel->out_kernels().begin(), kernel->out_kernels().end(),
|
||||
|
@ -159,38 +161,26 @@ int TransFormAxis(int axis) {
|
|||
}
|
||||
}
|
||||
|
||||
int NPUFusionPass::AddFusion(kernel::LiteKernel *kernel) {
|
||||
if (!CheckFusion(kernel)) {
|
||||
return RET_OK;
|
||||
}
|
||||
void NPUFusionPass::UpdateKernel(kernel::LiteKernel *kernel) {
|
||||
UpdatePreTensors(kernel);
|
||||
UpdatePostTensors(kernel);
|
||||
UpdatePreKernels(kernel);
|
||||
UpdatePostKernels(kernel);
|
||||
}
|
||||
|
||||
int NPUFusionPass::CommonFusion(kernel::LiteKernel *kernel) {
|
||||
UpdateKernel(kernel);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int NPUFusionPass::ConcatFusion(kernel::LiteKernel *kernel) {
|
||||
if (!CheckFusion(kernel)) {
|
||||
return RET_OK;
|
||||
}
|
||||
UpdatePreTensors(kernel);
|
||||
UpdatePostTensors(kernel);
|
||||
UpdatePreKernels(kernel);
|
||||
UpdatePostKernels(kernel);
|
||||
UpdateKernel(kernel);
|
||||
auto concat_param = reinterpret_cast<ConcatParameter *>(kernel->op_parameter());
|
||||
concat_param->axis_ = TransFormAxis(concat_param->axis_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
|
||||
if (kernel->out_kernels().empty()) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (!CheckFormatFusion(kernel)) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
auto pre_kernel = kernel->in_kernels()[0];
|
||||
auto in_tensor = kernel->in_tensors()[0];
|
||||
auto out_tensor = kernel->out_tensors()[0];
|
||||
|
@ -237,17 +227,28 @@ int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
|
|||
}
|
||||
|
||||
int NPUFusionPass::Run() {
|
||||
for (auto kernel : *kernels) {
|
||||
for (size_t i = 0; i < kernels->size(); i++) {
|
||||
auto kernel = (*kernels)[i];
|
||||
if (kernel->Type() == schema::PrimitiveType_Nchw2Nhwc || kernel->Type() == schema::PrimitiveType_Nchw2Nhwc) {
|
||||
if (CheckFormatFusion(kernel)) {
|
||||
i--;
|
||||
FormatFusion(kernel);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (!CheckFusion(kernel)) {
|
||||
continue;
|
||||
}
|
||||
switch (kernel->Type()) {
|
||||
case schema::PrimitiveType_Concat:
|
||||
i -= kernel->in_kernels().size();
|
||||
ConcatFusion(kernel);
|
||||
continue;
|
||||
case schema::PrimitiveType_Add:
|
||||
case schema::PrimitiveType_Activation:
|
||||
AddFusion(kernel);
|
||||
continue;
|
||||
case schema::PrimitiveType_Nchw2Nhwc:
|
||||
FormatFusion(kernel);
|
||||
case schema::PrimitiveType_Eltwise:
|
||||
i -= kernel->in_kernels().size();
|
||||
CommonFusion(kernel);
|
||||
continue;
|
||||
default:
|
||||
continue;
|
||||
|
|
|
@ -33,11 +33,12 @@ class NPUFusionPass : public NPUBasePass {
|
|||
int Run() override;
|
||||
|
||||
protected:
|
||||
void RemoveAndFreeKernel(kernel::LiteKernel *cur_kernel);
|
||||
void UpdatePreKernels(kernel::LiteKernel *kernel);
|
||||
void UpdatePostKernels(kernel::LiteKernel *kernel);
|
||||
void RemoveAndFreeKernel(kernel::LiteKernel *cur_kernel);
|
||||
void UpdateKernel(kernel::LiteKernel *kernel);
|
||||
int CommonFusion(kernel::LiteKernel *kernel);
|
||||
int ConcatFusion(kernel::LiteKernel *kernel);
|
||||
int AddFusion(kernel::LiteKernel *kernel);
|
||||
int FormatFusion(kernel::LiteKernel *kernel);
|
||||
|
||||
private:
|
||||
|
|
|
@ -21,7 +21,9 @@ namespace mindspore::lite {
|
|||
using kernel::KERNEL_ARCH::kNPU;
|
||||
enum InsertState { InsertNone, PreInsert, PostInsert };
|
||||
|
||||
std::set<mindspore::schema::PrimitiveType> npu_insert_nodes = {schema::PrimitiveType_Concat, schema::PrimitiveType_Add};
|
||||
std::set<mindspore::schema::PrimitiveType> npu_insert_nodes = {schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
|
||||
schema::PrimitiveType_Eltwise,
|
||||
schema::PrimitiveType_Activation};
|
||||
|
||||
int GetInsertState(kernel::LiteKernel *kernel) {
|
||||
if (npu_insert_nodes.find(kernel->Type()) == npu_insert_nodes.end()) {
|
||||
|
@ -42,15 +44,54 @@ int GetInsertState(kernel::LiteKernel *kernel) {
|
|||
return InsertNone;
|
||||
}
|
||||
|
||||
int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::LiteKernel *cur_kernel,
|
||||
std::vector<kernel::LiteKernel *> *all_kernels,
|
||||
int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel,
|
||||
std::vector<kernel::LiteKernel *> *trans_kernels,
|
||||
std::vector<Tensor *> *all_tensors) {
|
||||
for (auto kernel : cur_kernel->in_kernels()) {
|
||||
if (kernel->Type() == schema::PrimitiveType_Nchw2Nhwc) {
|
||||
for (auto in_kernel : kernel->in_kernels()) {
|
||||
if (in_kernel->Type() == schema::PrimitiveType_Nchw2Nhwc) {
|
||||
continue;
|
||||
}
|
||||
auto nhwc_shape = cur_kernel->out_tensors()[0]->shape();
|
||||
auto nhwc_shape = in_kernel->out_tensors()[0]->shape();
|
||||
std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
|
||||
|
||||
auto nh2nc_tensor =
|
||||
new Tensor(in_kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
|
||||
std::vector<Tensor *> nh2nc_tensors = {nh2nc_tensor};
|
||||
all_tensors->push_back(nh2nc_tensors[0]);
|
||||
|
||||
auto nc2nh_tensor = new Tensor(nh2nc_tensor->data_type(), nhwc_shape, schema::Format_NCHW, Tensor::VAR);
|
||||
std::vector<Tensor *> nc2nh_tensors = {nc2nh_tensor};
|
||||
all_tensors->push_back(nc2nh_tensors[0]);
|
||||
|
||||
auto nh2nc_name = in_kernel->name() + "_nh2nc_" + std::to_string(total++);
|
||||
auto *nh2nc_kernel =
|
||||
NPUPassUtils::CreateNhwc2NchwKernel(in_kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name);
|
||||
trans_kernels->push_back(nh2nc_kernel);
|
||||
insert_primitive_.push_back(nh2nc_kernel->GetPrimitive());
|
||||
|
||||
auto nc2nh_name = in_kernel->name() + "_nc2nh_" + std::to_string(total++);
|
||||
auto *nc2nh_kernel = NPUPassUtils::CreateNchw2NhwcKernel(nh2nc_tensors, nc2nh_tensors, context, nc2nh_name);
|
||||
trans_kernels->push_back(nc2nh_kernel);
|
||||
insert_primitive_.push_back(nc2nh_kernel->GetPrimitive());
|
||||
|
||||
NPUPassUtils::UpdateKernel(nh2nc_kernel, {in_kernel}, {nc2nh_kernel}, in_kernel->out_tensors(), nh2nc_tensors);
|
||||
NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {kernel}, nh2nc_tensors, nc2nh_tensors);
|
||||
NPUPassUtils::UpdateNH2NCTransNodePreKernel(in_kernel, nh2nc_kernel, kernel);
|
||||
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(in_kernel, nc2nh_kernel, kernel);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int NPUInsertTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel,
|
||||
std::vector<kernel::LiteKernel *> *trans_kernels,
|
||||
std::vector<Tensor *> *all_tensors) {
|
||||
for (auto out_kernel : kernel->out_kernels()) {
|
||||
if (out_kernel->Type() == schema::PrimitiveType_Nhwc2Nchw) {
|
||||
continue;
|
||||
}
|
||||
auto nhwc_shape = kernel->out_tensors()[0]->shape();
|
||||
std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
|
||||
|
||||
auto nh2nc_tensor = new Tensor(kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
|
||||
std::vector<Tensor *> nh2nc_tensors = {nh2nc_tensor};
|
||||
all_tensors->push_back(nh2nc_tensors[0]);
|
||||
|
@ -61,52 +102,18 @@ int NPUInsertTransformPass::InsertPreNode(const InnerContext *context, kernel::L
|
|||
|
||||
auto nh2nc_name = kernel->name() + "_nh2nc_" + std::to_string(total++);
|
||||
auto *nh2nc_kernel = NPUPassUtils::CreateNhwc2NchwKernel(kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name);
|
||||
all_kernels->push_back(nh2nc_kernel);
|
||||
trans_kernels->push_back(nh2nc_kernel);
|
||||
insert_primitive_.push_back(nh2nc_kernel->GetPrimitive());
|
||||
|
||||
auto nc2nh_name = kernel->name() + "_nc2nh_" + std::to_string(total++);
|
||||
auto *nc2nh_kernel = NPUPassUtils::CreateNchw2NhwcKernel(nh2nc_tensors, nc2nh_tensors, context, nc2nh_name);
|
||||
all_kernels->push_back(nc2nh_kernel);
|
||||
trans_kernels->push_back(nc2nh_kernel);
|
||||
insert_primitive_.push_back(nc2nh_kernel->GetPrimitive());
|
||||
|
||||
NPUPassUtils::UpdateKernel(nh2nc_kernel, {kernel}, {nc2nh_kernel}, kernel->out_tensors(), nh2nc_tensors);
|
||||
NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {cur_kernel}, nh2nc_tensors, nc2nh_tensors);
|
||||
NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel, nh2nc_kernel, cur_kernel);
|
||||
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, nc2nh_kernel, cur_kernel);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int NPUInsertTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKernel *cur_kernel,
|
||||
std::vector<kernel::LiteKernel *> *all_kernels,
|
||||
std::vector<Tensor *> *all_tensors) {
|
||||
for (auto out_kernel : cur_kernel->out_kernels()) {
|
||||
if (out_kernel->Type() == schema::PrimitiveType_Nhwc2Nchw) {
|
||||
continue;
|
||||
}
|
||||
auto nhwc_shape = cur_kernel->out_tensors()[0]->shape();
|
||||
std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
|
||||
|
||||
auto nh2nc_tensor =
|
||||
new Tensor(cur_kernel->out_tensors()[0]->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
|
||||
std::vector<Tensor *> nh2nc_tensors = {nh2nc_tensor};
|
||||
all_tensors->push_back(nh2nc_tensors[0]);
|
||||
|
||||
auto nc2nh_tensor = new Tensor(nh2nc_tensor->data_type(), nhwc_shape, schema::Format_NCHW, Tensor::VAR);
|
||||
std::vector<Tensor *> nc2nh_tensors = {nc2nh_tensor};
|
||||
all_tensors->push_back(nc2nh_tensors[0]);
|
||||
|
||||
auto nh2nc_name = cur_kernel->name() + "_nh2nc_" + std::to_string(total++);
|
||||
auto *nh2nc_kernel =
|
||||
NPUPassUtils::CreateNhwc2NchwKernel(cur_kernel->out_tensors(), nh2nc_tensors, context, nh2nc_name);
|
||||
all_kernels->push_back(nh2nc_kernel);
|
||||
insert_primitive_.push_back(nh2nc_kernel->GetPrimitive());
|
||||
auto nc2nh_name = cur_kernel->name() + "_nc2nh_" + std::to_string(total++);
|
||||
auto *nc2nh_kernel = NPUPassUtils::CreateNchw2NhwcKernel(nh2nc_tensors, nc2nh_tensors, context, nc2nh_name);
|
||||
all_kernels->push_back(nc2nh_kernel);
|
||||
insert_primitive_.push_back(nc2nh_kernel->GetPrimitive());
|
||||
NPUPassUtils::UpdateKernel(nh2nc_kernel, {cur_kernel}, {nc2nh_kernel}, cur_kernel->out_tensors(), nh2nc_tensors);
|
||||
NPUPassUtils::UpdateKernel(nc2nh_kernel, {nh2nc_kernel}, {out_kernel}, nh2nc_tensors, nc2nh_tensors);
|
||||
NPUPassUtils::UpdateNH2NCTransNodePreKernel(cur_kernel, nh2nc_kernel, out_kernel);
|
||||
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(cur_kernel, nc2nh_kernel, out_kernel);
|
||||
NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel, nh2nc_kernel, out_kernel);
|
||||
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, nc2nh_kernel, out_kernel);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -41,11 +41,11 @@ class NPUInsertTransformPass : public NPUBasePass {
|
|||
int Run() override;
|
||||
|
||||
private:
|
||||
int InsertPreNode(const InnerContext *context, kernel::LiteKernel *cur_kernel,
|
||||
std::vector<kernel::LiteKernel *> *all_kernels, std::vector<Tensor *> *all_tensors);
|
||||
int InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel,
|
||||
std::vector<kernel::LiteKernel *> *trans_kernels, std::vector<Tensor *> *all_tensors);
|
||||
|
||||
int InsertPostNode(const InnerContext *context, kernel::LiteKernel *cur_kernel,
|
||||
std::vector<kernel::LiteKernel *> *all_kernels, std::vector<Tensor *> *all_tensors);
|
||||
int InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel,
|
||||
std::vector<kernel::LiteKernel *> *trans_kernels, std::vector<Tensor *> *all_tensors);
|
||||
|
||||
private:
|
||||
int total = 0;
|
||||
|
|
|
@ -100,25 +100,25 @@ void NPUPassUtils::UpdateKernel(kernel::LiteKernel *kernel, const std::vector<ke
|
|||
kernel->set_out_kernels(out_kernels);
|
||||
}
|
||||
|
||||
void NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
|
||||
kernel::LiteKernel *after_kernel) {
|
||||
void NPUPassUtils::UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *pre_kernel, kernel::LiteKernel *trans_kernel,
|
||||
kernel::LiteKernel *kernel) {
|
||||
std::vector<kernel::LiteKernel *> out_kernels;
|
||||
|
||||
for (auto out_kernel : kernel->out_kernels()) {
|
||||
if (out_kernel == after_kernel) {
|
||||
for (auto out_kernel : pre_kernel->out_kernels()) {
|
||||
if (out_kernel == kernel) {
|
||||
out_kernels.push_back(trans_kernel);
|
||||
} else {
|
||||
out_kernels.push_back(out_kernel);
|
||||
}
|
||||
}
|
||||
UpdateKernel(kernel, kernel->in_kernels(), out_kernels, kernel->in_tensors(), kernel->out_tensors());
|
||||
pre_kernel->set_out_kernels(out_kernels);
|
||||
}
|
||||
|
||||
void NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
|
||||
kernel::LiteKernel *next_kernel) {
|
||||
kernel::LiteKernel *post_kernel) {
|
||||
std::vector<kernel::LiteKernel *> cur_out_kernels;
|
||||
for (auto out_kernel : kernel->out_kernels()) {
|
||||
if (out_kernel == next_kernel) {
|
||||
if (out_kernel == post_kernel) {
|
||||
cur_out_kernels.push_back(trans_kernel);
|
||||
} else {
|
||||
cur_out_kernels.push_back(out_kernel);
|
||||
|
@ -130,45 +130,47 @@ void NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel::LiteKernel *kernel, ker
|
|||
std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
|
||||
kernel_out_tensor->set_format(schema::Format_NCHW);
|
||||
kernel_out_tensor->set_shape(nchw_shape);
|
||||
UpdateKernel(kernel, kernel->in_kernels(), cur_out_kernels, kernel->in_tensors(), {kernel_out_tensor});
|
||||
kernel->set_out_kernels(cur_out_kernels);
|
||||
kernel->set_out_tensors({kernel_out_tensor});
|
||||
}
|
||||
|
||||
void NPUPassUtils::UpdateNH2NCTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
|
||||
kernel::LiteKernel *before_kernel) {
|
||||
kernel::LiteKernel *pre_kernel) {
|
||||
std::vector<lite::Tensor *> cur_kernel_in_tensors = {trans_kernel->out_tensors()[0]};
|
||||
for (int i = 1; i < kernel->in_tensors().size(); i++) {
|
||||
cur_kernel_in_tensors.push_back(kernel->in_tensors()[i]);
|
||||
}
|
||||
std::vector<kernel::LiteKernel *> cur_in_kernels = {trans_kernel};
|
||||
for (int i = 0; i < kernel->in_kernels().size(); i++) {
|
||||
for (int i = 1; i < kernel->in_kernels().size(); i++) {
|
||||
auto in_kernel = kernel->in_kernels()[i];
|
||||
if (in_kernel != kernel) {
|
||||
cur_in_kernels.push_back(in_kernel);
|
||||
}
|
||||
}
|
||||
UpdateKernel(kernel, cur_in_kernels, kernel->out_kernels(), cur_kernel_in_tensors, kernel->out_tensors());
|
||||
kernel->set_in_kernels(cur_in_kernels);
|
||||
kernel->set_in_tensors({cur_kernel_in_tensors});
|
||||
}
|
||||
|
||||
void NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
|
||||
kernel::LiteKernel *next_kernel) {
|
||||
std::vector<Tensor *> next_in_tensors;
|
||||
for (auto next_in_tensor : next_kernel->in_tensors()) {
|
||||
if (next_in_tensor != kernel->out_tensors()[0]) {
|
||||
next_in_tensors.push_back(next_in_tensor);
|
||||
kernel::LiteKernel *post_kernel) {
|
||||
std::vector<Tensor *> post_in_tensors;
|
||||
for (auto post_in_tensor : post_kernel->in_tensors()) {
|
||||
if (post_in_tensor != kernel->out_tensors()[0]) {
|
||||
post_in_tensors.push_back(post_in_tensor);
|
||||
} else {
|
||||
next_in_tensors.push_back(trans_kernel->out_tensors()[0]);
|
||||
post_in_tensors.push_back(trans_kernel->out_tensors()[0]);
|
||||
}
|
||||
}
|
||||
next_kernel->set_in_tensors(next_in_tensors);
|
||||
std::vector<kernel::LiteKernel *> next_in_kernels;
|
||||
for (auto in_kernel : next_kernel->in_kernels()) {
|
||||
post_kernel->set_in_tensors(post_in_tensors);
|
||||
std::vector<kernel::LiteKernel *> post_in_kernels;
|
||||
for (auto in_kernel : post_kernel->in_kernels()) {
|
||||
if (in_kernel == kernel) {
|
||||
next_in_kernels.push_back(trans_kernel);
|
||||
post_in_kernels.push_back(trans_kernel);
|
||||
} else {
|
||||
next_in_kernels.push_back(in_kernel);
|
||||
post_in_kernels.push_back(in_kernel);
|
||||
}
|
||||
}
|
||||
NPUPassUtils::UpdateKernel(next_kernel, next_in_kernels, next_kernel->out_kernels(), next_in_tensors,
|
||||
next_kernel->out_tensors());
|
||||
post_kernel->set_in_kernels(post_in_kernels);
|
||||
post_kernel->set_in_tensors({post_in_tensors});
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -35,17 +35,17 @@ class NPUPassUtils {
|
|||
const std::vector<kernel::LiteKernel *> &out_kernels,
|
||||
const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors);
|
||||
|
||||
static void UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
|
||||
kernel::LiteKernel *after_kernel);
|
||||
static void UpdateNH2NCTransNodePreKernel(kernel::LiteKernel *pre_kernel, kernel::LiteKernel *trans_kernel,
|
||||
kernel::LiteKernel *kernel);
|
||||
|
||||
static void UpdateNC2NHTransNodePreKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
|
||||
kernel::LiteKernel *next_kernel);
|
||||
kernel::LiteKernel *post_kernel);
|
||||
|
||||
static void UpdateNH2NCTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
|
||||
kernel::LiteKernel *before_kernel);
|
||||
kernel::LiteKernel *pre_kernel);
|
||||
|
||||
static void UpdateNC2NHTransNodeAfterKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
|
||||
kernel::LiteKernel *next_kernel);
|
||||
kernel::LiteKernel *post_kernel);
|
||||
|
||||
private:
|
||||
static PrimitiveC *CreateNchw2NhwcPrimitive();
|
||||
|
|
|
@ -19,51 +19,53 @@
|
|||
#include "src/runtime/agent/npu/npu_manager.h"
|
||||
#include "src/runtime/agent/npu/optimizer/npu_pass_utils.h"
|
||||
namespace mindspore::lite {
|
||||
using kernel::KERNEL_ARCH::kCPU;
|
||||
using kernel::KERNEL_ARCH::kNPU;
|
||||
int NPUTransformPass::InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel,
|
||||
std::vector<kernel::LiteKernel *> *all_kernels,
|
||||
std::vector<kernel::LiteKernel *> *trans_kernels,
|
||||
std::vector<Tensor *> *all_tensors) {
|
||||
bool is_input_kernel = kernel->in_kernels().empty();
|
||||
if (is_input_kernel || kernel->in_kernels()[0]->desc().arch != kNPU ||
|
||||
npu_trans_nodes.find(kernel->in_kernels()[0]->Type()) == npu_trans_nodes.end()) {
|
||||
kernel::LiteKernel *before_kernel = nullptr;
|
||||
kernel::LiteKernel *pre_kernel = nullptr;
|
||||
if (!is_input_kernel) {
|
||||
before_kernel = kernel->in_kernels()[0];
|
||||
pre_kernel = kernel->in_kernels()[0];
|
||||
}
|
||||
// Create pre transform kernel out tensors.
|
||||
|
||||
// Create pre transform kernel's out tensor.
|
||||
auto nhwc_shape = kernel->in_tensors()[0]->shape();
|
||||
std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
|
||||
auto tensor = new Tensor(kernel->in_tensors()[0]->data_type(), nchw_shape, schema::Format_NCHW, Tensor::VAR);
|
||||
std::vector<Tensor *> pre_trans_out_tensors = {tensor};
|
||||
all_tensors->push_back(pre_trans_out_tensors[0]);
|
||||
// Replace the output tensor of the previous node
|
||||
|
||||
// Create pre transform kernel: Nhwc2Nchw
|
||||
auto name = kernel->name() + "_pre_trans" + "_Nhwc2Nchw_" + std::to_string(total++);
|
||||
auto *pre_trans_kernel =
|
||||
auto *trans_kernel =
|
||||
NPUPassUtils::CreateNhwc2NchwKernel({kernel->in_tensors()[0]}, pre_trans_out_tensors, context, name);
|
||||
// Insert Nhwc2Nchw into the front of the current queue
|
||||
all_kernels->push_back(pre_trans_kernel);
|
||||
insert_primitive_.push_back(pre_trans_kernel->GetPrimitive());
|
||||
// Replace the output kernel of the previous node
|
||||
|
||||
trans_kernels->push_back(trans_kernel);
|
||||
insert_primitive_.push_back(trans_kernel->GetPrimitive());
|
||||
|
||||
// Set in_kernels, out_kernels, in_tensors,out_tensors for transform kernel
|
||||
std::vector<kernel::LiteKernel *> pre_trans_in_kernel;
|
||||
if (is_input_kernel) {
|
||||
pre_trans_in_kernel = {};
|
||||
} else {
|
||||
pre_trans_in_kernel = {before_kernel};
|
||||
pre_trans_in_kernel = {pre_kernel};
|
||||
}
|
||||
NPUPassUtils::UpdateKernel(pre_trans_kernel, pre_trans_in_kernel, {kernel}, {kernel->in_tensors()[0]},
|
||||
NPUPassUtils::UpdateKernel(trans_kernel, pre_trans_in_kernel, {kernel}, {kernel->in_tensors()[0]},
|
||||
pre_trans_out_tensors);
|
||||
|
||||
if (before_kernel != nullptr) {
|
||||
NPUPassUtils::UpdateNH2NCTransNodePreKernel(before_kernel, pre_trans_kernel, kernel);
|
||||
if (pre_kernel != nullptr) {
|
||||
NPUPassUtils::UpdateNH2NCTransNodePreKernel(pre_kernel, trans_kernel, kernel);
|
||||
}
|
||||
NPUPassUtils::UpdateNH2NCTransNodeAfterKernel(kernel, pre_trans_kernel, before_kernel);
|
||||
NPUPassUtils::UpdateNH2NCTransNodeAfterKernel(kernel, trans_kernel, pre_kernel);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int NPUTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel,
|
||||
std::vector<kernel::LiteKernel *> *all_kernels,
|
||||
std::vector<kernel::LiteKernel *> *trans_kernels,
|
||||
std::vector<Tensor *> *all_tensors) {
|
||||
// Model output does not insert operator
|
||||
if (kernel->out_kernels().empty()) {
|
||||
|
@ -71,27 +73,30 @@ int NPUTransformPass::InsertPostNode(const InnerContext *context, kernel::LiteKe
|
|||
}
|
||||
// Single output multiple references
|
||||
for (int i = 0; i < kernel->out_kernels().size(); i++) {
|
||||
auto next_kernel = kernel->out_kernels().at(i);
|
||||
if (next_kernel->desc().arch == kNPU && npu_trans_nodes.find(next_kernel->Type()) != npu_trans_nodes.end()) {
|
||||
auto post_kernel = kernel->out_kernels().at(i);
|
||||
if (post_kernel->desc().arch == kNPU && npu_trans_nodes.find(post_kernel->Type()) != npu_trans_nodes.end()) {
|
||||
continue;
|
||||
}
|
||||
// Change format the output of the current kernel nhwc->nchw
|
||||
|
||||
// Create post transform kernel's out tensor.
|
||||
auto tensor = new Tensor(kernel->out_tensors()[0]->data_type(), kernel->out_tensors()[0]->shape(),
|
||||
schema::Format_NHWC, Tensor::VAR);
|
||||
std::vector<Tensor *> post_trans_out_tensors = {tensor};
|
||||
all_tensors->push_back(post_trans_out_tensors[0]);
|
||||
// Use the output tensor of the current node as the input tensor of the post-conversion operator
|
||||
|
||||
// Create post transform kernel: Nchw2Nhwc
|
||||
auto name = kernel->name() + "_post_trans" + "_Nchw2Nhwc" + std::to_string(total++);
|
||||
auto *post_trans_kernel =
|
||||
NPUPassUtils::CreateNchw2NhwcKernel(kernel->out_tensors(), post_trans_out_tensors, context, name);
|
||||
// Replace the input tensor of the next node
|
||||
NPUPassUtils::UpdateKernel(post_trans_kernel, {kernel}, {next_kernel}, kernel->out_tensors(),
|
||||
|
||||
// Set in_kernels, out_kernels, in_tensors,out_tensors for transform kernel
|
||||
NPUPassUtils::UpdateKernel(post_trans_kernel, {kernel}, {post_kernel}, kernel->out_tensors(),
|
||||
post_trans_out_tensors);
|
||||
insert_primitive_.push_back(post_trans_kernel->GetPrimitive());
|
||||
// Directly insert in the back, will not affect the topological sort
|
||||
all_kernels->push_back(post_trans_kernel);
|
||||
NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel, post_trans_kernel, next_kernel);
|
||||
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, post_trans_kernel, next_kernel);
|
||||
|
||||
trans_kernels->push_back(post_trans_kernel);
|
||||
NPUPassUtils::UpdateNC2NHTransNodePreKernel(kernel, post_trans_kernel, post_kernel);
|
||||
NPUPassUtils::UpdateNC2NHTransNodeAfterKernel(kernel, post_trans_kernel, post_kernel);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -43,10 +43,10 @@ class NPUTransformPass : public NPUBasePass {
|
|||
|
||||
private:
|
||||
int InsertPreNode(const InnerContext *context, kernel::LiteKernel *kernel,
|
||||
std::vector<kernel::LiteKernel *> *all_kernels, std::vector<Tensor *> *all_tensors);
|
||||
std::vector<kernel::LiteKernel *> *trans_kernels, std::vector<Tensor *> *all_tensors);
|
||||
|
||||
int InsertPostNode(const InnerContext *context, kernel::LiteKernel *kernel,
|
||||
std::vector<kernel::LiteKernel *> *all_kernels, std::vector<Tensor *> *all_tensors);
|
||||
std::vector<kernel::LiteKernel *> *trans_kernels, std::vector<Tensor *> *all_tensors);
|
||||
|
||||
private:
|
||||
int total = 0;
|
||||
|
|
Loading…
Reference in New Issue