!12468 [MSLITE][Develop] fix bug of npu kernel fusion

From: @yangruoqi713
Reviewed-by: @zhang_xue_tong, @zhanghaibo5
Signed-off-by: @zhang_xue_tong
Committed-by: mindspore-ci-bot, 2021-02-25 09:35:48 +08:00 (via Gitee)
Commit: cf8279267c
5 changed files with 51 additions and 22 deletions

==== File 1/5: NPUFusionPass ====

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -197,7 +197,11 @@ int NPUFusionPass::ConcatFusion(kernel::LiteKernel *kernel) {
 }
 int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
-  auto pre_kernel = kernel->in_kernels()[0];
+  auto is_input_kernel = kernel->in_kernels().empty();
+  kernel::LiteKernel *pre_kernel = nullptr;
+  if (!is_input_kernel) {
+    pre_kernel = kernel->in_kernels()[0];
+  }
   auto in_tensor = kernel->in_tensors()[0];
   std::vector<kernel::LiteKernel *> pre_insert_kernels;
   for (const auto &trans_kernel : kernel->out_kernels()) {
@@ -225,7 +229,11 @@ int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
       auto post_in_kernels = post_kernel->in_kernels();
       for (size_t i = 0; i < post_in_kernels.size(); i++) {
         if (post_in_kernels[i] == trans_kernel) {
-          post_in_kernels[i] = pre_kernel;
+          if (is_input_kernel) {
+            post_in_kernels.erase(post_in_kernels.begin() + i);
+          } else {
+            post_in_kernels[i] = pre_kernel;
+          }
           break;
         }
       }
@@ -234,16 +242,18 @@ int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
     }
     RemoveAndFreeKernel(trans_kernel);
   }
-  auto pre_out_kernels = pre_kernel->out_kernels();
-  size_t index = 0;
-  for (; index < pre_out_kernels.size(); index++) {
-    if (pre_out_kernels[index] == kernel) {
-      pre_out_kernels.erase(pre_out_kernels.begin() + index);
-      break;
+  if (!is_input_kernel) {
+    auto pre_out_kernels = pre_kernel->out_kernels();
+    size_t index = 0;
+    for (; index < pre_out_kernels.size(); index++) {
+      if (pre_out_kernels[index] == kernel) {
+        pre_out_kernels.erase(pre_out_kernels.begin() + index);
+        break;
+      }
     }
-  }
-  pre_out_kernels.insert(pre_out_kernels.begin() + index, pre_insert_kernels.begin(), pre_insert_kernels.end());
-  pre_kernel->set_out_kernels(pre_out_kernels);
+    pre_out_kernels.insert(pre_out_kernels.begin() + index, pre_insert_kernels.begin(), pre_insert_kernels.end());
+    pre_kernel->set_out_kernels(pre_out_kernels);
+  }
   RemoveAndFreeKernel(kernel);
   return RET_OK;
 }
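Note on the FormatFusion change above: when the format-transpose kernel consumes a graph input, kernel->in_kernels() is empty, so the old unconditional in_kernels()[0] indexed an empty vector, which is undefined behavior. A minimal self-contained sketch of the guard, with FakeKernel standing in for the real kernel::LiteKernel:

#include <cassert>
#include <vector>

// FakeKernel is a stand-in for kernel::LiteKernel, for illustration only.
struct FakeKernel {
  std::vector<FakeKernel *> in_kernels;
};

// Old code did the equivalent of kernel.in_kernels[0] unconditionally;
// for a kernel fed by a graph input the vector is empty, and operator[]
// on an empty std::vector is undefined behavior.
FakeKernel *PreKernelOrNull(const FakeKernel &kernel) {
  if (kernel.in_kernels.empty()) {
    return nullptr;  // graph-input case: no producer kernel exists
  }
  return kernel.in_kernels[0];
}

int main() {
  FakeKernel graph_input_consumer;  // consumes a graph input, no producers
  assert(PreKernelOrNull(graph_input_consumer) == nullptr);
  return 0;
}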

==== File 2/5: NPUInsertTransformPass ====

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -116,28 +116,30 @@ int NPUInsertTransformPass::InsertNode(kernel::LiteKernel *kernel, kernel::LiteK
   std::vector<int> nhwc_shape = in_tensor->shape();
   std::vector<int> nchw_shape = {nhwc_shape[0], nhwc_shape[3], nhwc_shape[1], nhwc_shape[2]};
+  auto nh2nc_name = kernel_name + "_nh2nc_" + std::to_string(total++);
   auto nh2nc_tensor = new (std::nothrow) Tensor(in_tensor->data_type(), nchw_shape, schema::Format_NHWC, Tensor::VAR);
   if (nh2nc_tensor == nullptr) {
     MS_LOG(ERROR) << "New nchw tensor failed when inserting nchw2nhwc kernel.";
     return RET_ERROR;
   }
+  nh2nc_tensor->set_tensor_name(nh2nc_name + "/output0");
   std::vector<Tensor *> nh2nc_tensors = {nh2nc_tensor};
   all_tensors_->push_back(nh2nc_tensors[0]);
+  auto nc2nh_name = kernel_name + "_nc2nh_" + std::to_string(total++);
   auto nc2nh_tensor = new (std::nothrow) Tensor(in_tensor->data_type(), nhwc_shape, schema::Format_NCHW, Tensor::VAR);
   if (nc2nh_tensor == nullptr) {
     MS_LOG(ERROR) << "New nhwc tensor failed when inserting nhwc2nchw kernel.";
     return RET_ERROR;
   }
+  nc2nh_tensor->set_tensor_name(nc2nh_name + "/output0");
   std::vector<Tensor *> nc2nh_tensors = {nc2nh_tensor};
   all_tensors_->push_back(nc2nh_tensors[0]);
-  auto nh2nc_name = kernel_name + "_nh2nc_" + std::to_string(total++);
   auto *nh2nc_kernel = NPUPassUtils::CreateNhwc2NchwKernel({in_tensor}, nh2nc_tensors, context_, nh2nc_name);
   trans_kernels->push_back(nh2nc_kernel);
   insert_primitive_.push_back(nh2nc_kernel->GetPrimitive());
-  auto nc2nh_name = kernel_name + "_nc2nh_" + std::to_string(total++);
   auto *nc2nh_kernel = NPUPassUtils::CreateNchw2NhwcKernel(nh2nc_tensors, nc2nh_tensors, context_, nc2nh_name);
   trans_kernels->push_back(nc2nh_kernel);
   insert_primitive_.push_back(nc2nh_kernel->GetPrimitive());
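Note on the InsertNode change above: the nh2nc/nc2nh name strings are now built before the tensors are created, so each new tensor can be named immediately via set_tensor_name instead of staying anonymous. A small sketch of that pattern, with FakeTensor standing in for the real Tensor class and "conv1" as a hypothetical kernel name:

#include <iostream>
#include <string>

// FakeTensor is a stand-in for the real Tensor class, for illustration only.
struct FakeTensor {
  std::string name;
  void set_tensor_name(const std::string &n) { name = n; }
};

int main() {
  int total = 0;
  std::string kernel_name = "conv1";  // hypothetical kernel name
  // Build the name first, then create the tensor and name it immediately.
  auto nh2nc_name = kernel_name + "_nh2nc_" + std::to_string(total++);
  FakeTensor nh2nc_tensor;
  nh2nc_tensor.set_tensor_name(nh2nc_name + "/output0");
  std::cout << nh2nc_tensor.name << std::endl;  // prints conv1_nh2nc_0/output0
  return 0;
}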

==== File 3/5: NPUPassUtils ====

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -191,7 +191,11 @@ void NPUPassUtils::UpdateNC2NHTransNodePostKernel(kernel::LiteKernel *kernel, ke
   // For post_kernel after trans, kernel in in_kernels should be replaced with trans_kernel.
   auto post_in_kernels = post_kernel->in_kernels();
-  std::replace(post_in_kernels.begin(), post_in_kernels.end(), kernel, trans_kernel);
+  if (kernel == nullptr) {
+    post_in_kernels.push_back(trans_kernel);
+  } else {
+    std::replace(post_in_kernels.begin(), post_in_kernels.end(), kernel, trans_kernel);
+  }
   post_kernel->set_in_kernels(post_in_kernels);
 }
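Note on the UpdateNC2NHTransNodePostKernel change above: when kernel is nullptr (the trans kernel hangs directly off a graph input), std::replace finds no match and is a silent no-op, so trans_kernel never got linked into the post kernel's inputs; the fix appends it instead. A minimal sketch of the fixed update, with int* standing in for kernel::LiteKernel*:

#include <algorithm>
#include <cassert>
#include <vector>

// std::replace is a silent no-op when `kernel` is nullptr, because nullptr
// never appears in the input list; the fix appends `trans` instead.
void UpdatePostInKernels(std::vector<int *> *post_in, int *kernel, int *trans) {
  if (kernel == nullptr) {
    post_in->push_back(trans);  // graph-input case: nothing to replace
  } else {
    std::replace(post_in->begin(), post_in->end(), kernel, trans);
  }
}

int main() {
  int other = 0, trans = 1;
  std::vector<int *> post_in = {&other};
  UpdatePostInKernels(&post_in, nullptr, &trans);
  // The old code would have left post_in unchanged here.
  assert(post_in.size() == 2 && post_in[1] == &trans);
  return 0;
}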

==== File 4/5: NPUTransformPass ====

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -39,11 +39,12 @@ int NPUTransformPass::InsertPreNodes(kernel::LiteKernel *kernel, std::vector<ker
     MS_LOG(ERROR) << "New nchw tensor failed when inserting pre nhwc2nchw kernel.";
     return RET_ERROR;
   }
+  auto name = kernel->name() + "_pre_trans" + "_Nhwc2Nchw_" + std::to_string(total++);
+  tensor->set_tensor_name(name + "/output0");
   std::vector<Tensor *> pre_trans_out_tensors = {tensor};
   all_tensors_->push_back(pre_trans_out_tensors[0]);
   // Create pre transform kernel: Nhwc2Nchw
-  auto name = kernel->name() + "_pre_trans" + "_Nhwc2Nchw_" + std::to_string(total++);
   auto *trans_kernel =
     NPUPassUtils::CreateNhwc2NchwKernel({kernel->in_tensors()[0]}, pre_trans_out_tensors, context_, name);
@@ -124,6 +125,11 @@ int NPUTransformPass::Run() {
       i++;
       continue;
    }
+    if (kernel->Type() == schema::PrimitiveType_Resize &&
+        kernel->in_tensors()[0]->Height() > kernel->out_tensors()[0]->Height()) {
+      i++;
+      continue;
+    }
     // insert pre_kernels before kernel in vector
     // modify loop index add (pre_kernels.size() + 1) to the post_kernels insert location
     std::vector<kernel::LiteKernel *> pre_kernels;
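Note on the Run() change above: the pass now skips a Resize kernel whose output is smaller than its input, matching the new reduction check added to ResizeNPUKernel::IsSupport in file 5/5, so no transpose kernels are inserted around an op that will stay on CPU. A tiny sketch of that predicate, with an illustrative Shape struct standing in for the tensor Height()/Width() accessors:

#include <cassert>

// Shape is an illustrative stand-in; the real check reads tensor Height()/Width().
struct Shape {
  int h;
  int w;
};

// NPU resize supports enlarging only; a shrinking resize stays on CPU,
// so the transform pass skips it and inserts no transpose kernels around it.
bool ResizeSupportedOnNpu(const Shape &in, const Shape &out) {
  return in.h <= out.h && in.w <= out.w;
}

int main() {
  assert(ResizeSupportedOnNpu({16, 16}, {32, 32}));   // upscale: run on NPU
  assert(!ResizeSupportedOnNpu({32, 32}, {16, 16}));  // downscale: CPU fallback
  return 0;
}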

==== File 5/5: ResizeNPUKernel ====

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,11 +26,15 @@ using mindspore::schema::PrimitiveType_Resize;
 namespace mindspore::kernel {
 int ResizeNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                                OpParameter *opParameter) {
-  if (resize_parameter_->method_ != schema::ResizeMethod_LINEAR ||
-      resize_parameter_->method_ == schema::ResizeMethod_NEAREST) {
+  if (resize_parameter_->method_ != schema::ResizeMethod_LINEAR &&
+      resize_parameter_->method_ != schema::ResizeMethod_NEAREST) {
     MS_LOG(WARNING) << "Unsupported resize method type:" << resize_parameter_->method_;
     return RET_ERROR;
   }
+  if (inputs[0]->Height() > outputs[0]->Height() || inputs[0]->Width() > outputs[0]->Width()) {
+    MS_LOG(WARNING) << "Npu resize does not support reduction.";
+    return RET_ERROR;
+  }
   return RET_OK;
 }
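Note on the IsSupport fix above: the old condition `method_ != LINEAR || method_ == NEAREST` is true for NEAREST, so a supported method was rejected; with `&&` and `!=` only genuinely unsupported methods fail the check. A self-contained truth-table check, with a plain Method enum standing in for schema::ResizeMethod:

#include <cassert>

// Method is a stand-in for schema::ResizeMethod, for illustration only.
enum Method { LINEAR, NEAREST, CUBIC };

bool OldRejects(Method m) { return m != LINEAR || m == NEAREST; }
bool NewRejects(Method m) { return m != LINEAR && m != NEAREST; }

int main() {
  assert(!OldRejects(LINEAR) && !NewRejects(LINEAR));  // LINEAR accepted by both
  assert(OldRejects(NEAREST));                         // the bug: NEAREST rejected
  assert(!NewRejects(NEAREST));                        // the fix: NEAREST accepted
  assert(OldRejects(CUBIC) && NewRejects(CUBIC));      // unsupported still rejected
  return 0;
}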
@@ -55,7 +59,7 @@ int ResizeNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, con
     op->set_input_size(*out_size);
     op->set_attr_half_pixel_centers(resize_parameter_->preserve_aspect_ratio_);
     op_ = op;
-  } else {
+  } else if (resize_parameter_->method_ == schema::ResizeMethod_NEAREST) {
     auto op = new (std::nothrow) hiai::op::ResizeNearestNeighborV2(name_);
     if (op == nullptr) {
       MS_LOG(ERROR) << " op is nullptr.";
@@ -66,6 +70,9 @@ int ResizeNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, con
     op->set_input_x(*npu_inputs[0]);
     op->set_input_size(*out_size);
     op_ = op;
+  } else {
+    MS_LOG(WARNING) << "Unsupported resize method type:" << resize_parameter_->method_;
+    return RET_ERROR;
   }
   return RET_OK;
 }
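Note on the SetNPUInputs change above: the old trailing `else` built a ResizeNearestNeighborV2 op for any non-linear method; the explicit `else if` plus a rejecting default makes the dispatch exhaustive. A sketch of that pattern (Method and the printf bodies are illustrative stand-ins, not the hiai API):

#include <cstdio>

// Method and the printf bodies are illustrative stand-ins, not the hiai API.
enum Method { LINEAR, NEAREST, CUBIC };

int BuildResizeOp(Method m) {
  if (m == LINEAR) {
    printf("build bilinear resize op\n");
  } else if (m == NEAREST) {
    printf("build nearest-neighbor resize op\n");
  } else {
    // The old trailing else landed here for ANY non-linear method and
    // silently built a nearest-neighbor op; now it is an explicit error.
    printf("unsupported resize method\n");
    return -1;
  }
  return 0;
}

int main() {
  BuildResizeOp(LINEAR);
  BuildResizeOp(CUBIC);  // reports an error instead of mis-dispatching
  return 0;
}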