insert transform pass

zhaozhenlong 2021-02-07 10:43:03 +08:00
parent f9f24ca94d
commit 03ecf5be34
6 changed files with 224 additions and 100 deletions


@@ -246,31 +246,31 @@ int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
 int NPUFusionPass::Run() {
   for (size_t i = 0; i < kernels->size(); i++) {
     auto kernel = (*kernels)[i];
-    if (NPUPassUtils::IsNchw2Nhwc(kernel) || NPUPassUtils::IsNhwc2Nchw(kernel)) {
-      if (CheckFormatFusion(kernel)) {
-        i--;
-        FormatFusion(kernel);
-      }
-      continue;
-    }
-    if (!CheckFusion(kernel)) {
-      continue;
-    }
-    switch (kernel->Type()) {
-      case schema::PrimitiveType_Concat:
-        i -= kernel->in_kernels().size();
-        ConcatFusion(kernel);
-        continue;
-      case schema::PrimitiveType_Add:
-      case schema::PrimitiveType_Activation:
-      case schema::PrimitiveType_Eltwise:
-        i -= kernel->in_kernels().size();
-        CommonFusion(kernel);
-        continue;
-      default:
-        continue;
-    }
+    if (CheckFusion(kernel)) {
+      switch (kernel->Type()) {
+        case schema::PrimitiveType_Concat:
+          i -= kernel->in_kernels().size();
+          ConcatFusion(kernel);
+          continue;
+        case schema::PrimitiveType_Add:
+        case schema::PrimitiveType_Activation:
+        case schema::PrimitiveType_Eltwise:
+          i -= kernel->in_kernels().size();
+          CommonFusion(kernel);
+          continue;
+        default:
+          continue;
+      }
+    }
   }
+  for (size_t i = 0; i < kernels->size(); ++i) {
+    auto kernel = (*kernels)[i];
+    if (CheckFormatFusion(kernel)) {
+      i--;
+      FormatFusion(kernel);
+    }
+  }
   return RET_OK;
 }
 }  // namespace mindspore::lite
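
The `i -= kernel->in_kernels().size()` lines are the subtle part of this hunk: fusion removes the transpose kernels feeding position i from *kernels, which shifts the current kernel backwards, so the index must be rewound before the loop's increment. A minimal standalone sketch of the same pattern, using a plain std::vector<int> and hypothetical values rather than real kernels:

#include <iostream>
#include <vector>

int main() {
  // Pretend element 3 is a fusable kernel whose two predecessors (1 and 2) are
  // transposes that fusion erases from the list.
  std::vector<int> v = {1, 2, 3, 4, 5};
  for (size_t i = 0; i < v.size(); i++) {
    if (v[i] == 3) {
      size_t n = 2;  // stands in for kernel->in_kernels().size()
      v.erase(v.begin() + (i - n), v.begin() + i);
      i -= n;        // the fused kernel now lives at index i - n
      continue;      // the loop's i++ moves on to the kernel after it
    }
  }
  for (int x : v) {
    std::cout << x << ' ';  // prints: 3 4 5
  }
  std::cout << '\n';
  return 0;
}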


@@ -20,31 +20,81 @@
 namespace mindspore::lite {
 using kernel::KERNEL_ARCH::kNPU;
-enum InsertState { InsertNone, PreInsert, PostInsert };
+enum InsertState { InsertNone, PreInsert, PostInsert, BothInsert };
 std::set<mindspore::schema::PrimitiveType> npu_insert_nodes = {schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
                                                                schema::PrimitiveType_Eltwise,
                                                                schema::PrimitiveType_Activation};
+// The goal of this pass is to minimize the number of subgraphs generated, by inserting
+// nchw2nhwc or nhwc2nchw before or after operators such as concat and add, in concert
+// with the fusion pass. If more than half of an op's inputs and outputs are already
+// transposes, we insert transposes on the remaining inputs and outputs and let the
+// fusion pass remove them all. Otherwise, we insert nothing.
+//
+// Typically concat accepts output from nchw2nhwc; we wrap the other inputs with
+// nhwc2nchw->nchw2nhwc pairs so that all inputs to concat share the same format, and
+// the fusion pass then removes all the nchw2nhwc ops.
+// e.g.
+//   original:    (conv->nchw2nhwc, add(format nhwc)) -> concat -> (nhwc2nchw->conv)
+//   this pass:   (conv->nchw2nhwc, add->nhwc2nchw->nchw2nhwc) -> concat -> (nhwc2nchw->conv)
+//   fusion pass: (conv, add->nhwc2nchw) -> concat -> conv
+// The original graph needs 2 CPU subgraphs; after the two passes, only 1 remains.
+//
+// Note:
+// These ops require all of their inputs to share one format, which may be nchw, nhwc,
+// or something else. Their inputs and outputs may not be 4-D, or may already agree in
+// format, so we insert nothing when the op's in kernels and out kernels contain no
+// nc2nh or nh2nc.
+// This pass should run after npu_transform_pass, which inserts transposes for
+// nchw-input-limited ops such as conv2d.
 int GetInsertState(kernel::LiteKernel *kernel) {
+  // Filter out irrelevant kernels.
   if (npu_insert_nodes.find(kernel->Type()) == npu_insert_nodes.end()) {
     return InsertNone;
   }
-  auto pre_flag = std::all_of(kernel->in_kernels().begin(), kernel->in_kernels().end(),
-                              [](const kernel::LiteKernel *kernel) { return NPUPassUtils::IsNchw2Nhwc(kernel); });
-  auto post_flag = std::all_of(kernel->out_kernels().begin(), kernel->out_kernels().end(),
-                               [](const kernel::LiteKernel *kernel) { return NPUPassUtils::IsNhwc2Nchw(kernel); });
-  if (pre_flag && !post_flag) {
-    return PostInsert;
-  }
-  if (!pre_flag && post_flag) {
+  // The current kernel is a target kernel.
+  // Use out_kernels to count how many output edges leave the current kernel.
+  size_t in_out_tensor_num = kernel->in_tensors().size() + kernel->out_kernels().size();
+  size_t transpose_input_num = 0;
+  size_t transpose_output_num = 0;
+  bool need_pre_insert = false;
+  bool need_post_insert = false;
+  // Count the input tensors coming from nc2nh and the output tensors going to nh2nc.
+  for (size_t i = 0; i < kernel->in_tensors().size(); ++i) {
+    auto in_kernel = NPUPassUtils::KernelInputFromKernel(kernel, i);
+    if (NPUPassUtils::IsNchw2Nhwc(in_kernel)) {
+      transpose_input_num++;
+    } else {
+      need_pre_insert = true;
+    }
+  }
+  for (const auto out_kernel : kernel->out_kernels()) {
+    if (NPUPassUtils::IsNhwc2Nchw(out_kernel)) {
+      transpose_output_num++;
+    } else {
+      need_post_insert = true;
+    }
+  }
+  // Insert nothing if the number of transpose tensors is no more than half of the total
+  // inputs and outputs. Also insert nothing if all of them are already transposes; the
+  // fusion pass handles that case.
+  size_t transpose_tensor_num = transpose_input_num + transpose_output_num;
+  if (transpose_tensor_num <= in_out_tensor_num / 2 || transpose_tensor_num == in_out_tensor_num) {
+    return InsertNone;
+  }
+  if (need_pre_insert && !need_post_insert) {
     return PreInsert;
   }
+  if (need_pre_insert && need_post_insert) {
+    return BothInsert;
+  }
+  if (!need_pre_insert && need_post_insert) {
+    return PostInsert;
+  }
   return InsertNone;
 }
 int NPUInsertTransformPass::InsertNode(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel,
-                                       std::vector<kernel::LiteKernel *> *trans_kernels) {
+                                       size_t post_input_index, std::vector<kernel::LiteKernel *> *trans_kernels) {
   // Kernel and post_kernel can't be nullptr at the same time.
   std::string kernel_name;
   Tensor *in_tensor = nullptr;
@@ -54,7 +104,7 @@ int NPUInsertTransformPass::InsertNode(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel,
   if (post_kernel != nullptr) {
     out_kernels.push_back(post_kernel);
     kernel_name = post_kernel->name() + "_pre";
-    in_tensor = post_kernel->in_tensors()[0];
+    in_tensor = post_kernel->in_tensors().at(post_input_index);
   }
   std::vector<kernel::LiteKernel *> in_kernels;
   // If kernel equals nullptr, post_kernel is the input of whole graph.
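
The new post_input_index parameter is what lets the inserted transpose splice into the correct slot of a multi-input post_kernel; the old code always rewired input 0. A small sketch of the difference, with plain ints as hypothetical tensor ids rather than real Tensor objects:

#include <cassert>
#include <vector>

int main() {
  // A two-input op whose *second* input is the tensor being wrapped in a transpose.
  int other = 10, from_kernel = 11, transposed = 12;  // hypothetical tensor ids
  std::vector<int> post_inputs = {other, from_kernel};
  // Old behaviour: in_tensors()[0] was rewired unconditionally, clobbering `other`.
  // New behaviour: rewire the slot located via std::find in InsertPostNodes.
  size_t post_input_index = 1;
  post_inputs.at(post_input_index) = transposed;
  assert(post_inputs[0] == other);
  assert(post_inputs[1] == transposed);
  return 0;
}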
@@ -99,87 +149,134 @@ int NPUInsertTransformPass::InsertNode(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel,
   }
   if (post_kernel != nullptr) {
     NPUPassUtils::UpdateNC2NHTransNodePostKernel(kernel, nc2nh_kernel, post_kernel);
+  } else {
+    // A nullptr post_kernel means this tensor is a graph output; keep the graph output
+    // tensor name unchanged.
+    auto graph_output_name = in_tensor->tensor_name();
+    in_tensor->set_tensor_name(graph_output_name + "_before_" + name_);
+    nc2nh_tensor->set_tensor_name(graph_output_name);
   }
   return RET_OK;
 }
+int NPUInsertTransformPass::InsertForInputTensor(kernel::LiteKernel *kernel, size_t in_tensor_index,
+                                                 kernel::LiteKernel *pre_kernel,
+                                                 std::vector<kernel::LiteKernel *> *trans_kernels) {
+  // Insert transpose nodes before the target op.
+  return InsertNode(pre_kernel, kernel, in_tensor_index, trans_kernels);
+}
+int NPUInsertTransformPass::InsertForOutputTensor(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel,
+                                                  size_t post_in_tensor_index,
+                                                  std::vector<kernel::LiteKernel *> *trans_kernels) {
+  // Insert transpose nodes after the target op.
+  return InsertNode(kernel, post_kernel, post_in_tensor_index, trans_kernels);
+}
 int NPUInsertTransformPass::InsertPreNodes(kernel::LiteKernel *kernel,
                                            std::vector<kernel::LiteKernel *> *trans_kernels) {
-  if (kernel->in_kernels().size() != kernel->in_tensors().size()) {
-    MS_LOG(DEBUG) << "The input tensors of kernel may be the input of whole graph or const tensor.";
-    return RET_OK;
-  }
-  if (kernel->in_kernels().empty()) {
-    auto ret = InsertNode(nullptr, kernel, trans_kernels);
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name() << " failed.";
-      return RET_ERROR;
-    }
-  }
-  for (auto in_kernel : kernel->in_kernels()) {
-    if (NPUPassUtils::IsNchw2Nhwc(in_kernel)) {
+  int ret = RET_OK;
+  for (size_t i = 0; i < kernel->in_tensors().size(); ++i) {
+    auto pre_kernel = NPUPassUtils::KernelInputFromKernel(kernel, i);
+    if (NPUPassUtils::IsNchw2Nhwc(pre_kernel)) {
       continue;
     }
-    auto ret = InsertNode(in_kernel, kernel, trans_kernels);
+    // If this tensor is a graph input, pre_kernel is nullptr.
+    ret = InsertForInputTensor(kernel, i, pre_kernel, trans_kernels);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name() << " failed.";
-      return RET_ERROR;
+      return ret;
     }
   }
-  return RET_OK;
+  return ret;
 }
 int NPUInsertTransformPass::InsertPostNodes(kernel::LiteKernel *kernel,
                                             std::vector<kernel::LiteKernel *> *trans_kernels) {
-  if (kernel->out_kernels().empty()) {
-    auto ret = InsertNode(kernel, nullptr, trans_kernels);
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
-      return RET_ERROR;
-    }
-  }
-  for (auto out_kernel : kernel->out_kernels()) {
-    if (NPUPassUtils::IsNhwc2Nchw(out_kernel)) {
+  int ret = RET_OK;
+  for (const auto post_kernel : kernel->out_kernels()) {
+    if (NPUPassUtils::IsNhwc2Nchw(post_kernel)) {
       continue;
     }
-    auto ret = InsertNode(kernel, out_kernel, trans_kernels);
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
+    auto post_kernel_in_tensors = post_kernel->in_tensors();
+    // The kernel's out tensor is one of post_kernel's input tensors.
+    auto it = std::find(post_kernel_in_tensors.begin(), post_kernel_in_tensors.end(), kernel->out_tensors().at(0));
+    if (it == post_kernel_in_tensors.end()) {
       return RET_ERROR;
     }
+    size_t input_index = it - post_kernel_in_tensors.begin();
+    ret = InsertForOutputTensor(kernel, post_kernel, input_index, trans_kernels);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
+      return ret;
+    }
   }
-  return RET_OK;
+  if (kernel->out_tensors().size() > kernel->out_kernels().size()) {
+    // The kernel output is a graph output.
+    ret = InsertForOutputTensor(kernel, nullptr, 0, trans_kernels);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
+      return ret;
+    }
+  }
+  return ret;
 }
 int NPUInsertTransformPass::Run() {
+  std::vector<kernel::LiteKernel *> insert_kernels;
   for (size_t i = 0; i < all_kernels_->size(); i++) {
     auto kernel = (*all_kernels_)[i];
     if (kernel->desc().arch != kNPU) {
       continue;
     }
     auto insert_state = GetInsertState(kernel);
-    // If the every output kernel is nhwc2nchw, insert
-    // modify loop index add post_kernels.size() to the next kernel in the origin vector
-    if (insert_state == PreInsert) {
-      std::vector<kernel::LiteKernel *> pre_kernels;
-      auto ret = InsertPreNodes(kernel, &pre_kernels);
-      if (ret != RET_OK) {
-        MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name() << " failed.";
-        return RET_ERROR;
-      }
-      all_kernels_->insert(all_kernels_->begin() + i, pre_kernels.begin(), pre_kernels.end());
-      i += pre_kernels.size();
-    }
-    if (insert_state == PostInsert) {
-      std::vector<kernel::LiteKernel *> post_kernels;
-      auto ret = InsertPostNodes(kernel, &post_kernels);
-      if (ret != RET_OK) {
-        MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
-        return RET_ERROR;
-      }
-      all_kernels_->insert(all_kernels_->begin() + i + 1, post_kernels.begin(), post_kernels.end());
-      i += post_kernels.size();
-    }
+    insert_kernels.clear();
+    // Advance the loop index by insert_kernels.size() so that it still points at the
+    // next kernel of the original vector after the insertion.
+    switch (insert_state) {
+      case PreInsert: {
+        auto ret = InsertPreNodes(kernel, &insert_kernels);
+        if (ret != RET_OK) {
+          MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name()
+                        << " failed.";
+          return RET_ERROR;
+        }
+        all_kernels_->insert(all_kernels_->begin() + i, insert_kernels.begin(), insert_kernels.end());
+        i += insert_kernels.size();
+        break;
+      }
+      case PostInsert: {
+        auto ret = InsertPostNodes(kernel, &insert_kernels);
+        if (ret != RET_OK) {
+          MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
+          return RET_ERROR;
+        }
+        all_kernels_->insert(all_kernels_->begin() + i + 1, insert_kernels.begin(), insert_kernels.end());
+        i += insert_kernels.size();
+        break;
+      }
+      case BothInsert: {
+        auto ret = InsertPreNodes(kernel, &insert_kernels);
+        if (ret != RET_OK) {
+          MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel before kernel " << kernel->name()
+                        << " failed.";
+          return RET_ERROR;
+        }
+        all_kernels_->insert(all_kernels_->begin() + i, insert_kernels.begin(), insert_kernels.end());
+        i += insert_kernels.size();
+        insert_kernels.clear();
+        ret = InsertPostNodes(kernel, &insert_kernels);
+        if (ret != RET_OK) {
+          MS_LOG(ERROR) << "Insert nhwc2nchw kernel and nchw2nhwc kernel after kernel " << kernel->name() << " failed.";
+          return RET_ERROR;
+        }
+        all_kernels_->insert(all_kernels_->begin() + i + 1, insert_kernels.begin(), insert_kernels.end());
+        i += insert_kernels.size();
+        break;
+      }
+      default:
+        MS_LOG(DEBUG) << "Insert Nothing on kernel " << kernel->name();
+    }
   }
   return RET_OK;
 }
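
The decision rule in GetInsertState is easiest to see with concrete counts. Below is a minimal standalone sketch of the same arithmetic; Decide and its plain integer counts are hypothetical stand-ins for what the pass derives from in_tensors() and out_kernels():

#include <cstdio>

enum InsertState { InsertNone, PreInsert, PostInsert, BothInsert };

// Simplified restatement of GetInsertState: insert only when more than half,
// but not all, of an op's inputs and outputs already touch a transpose.
InsertState Decide(int transposed_in, int total_in, int transposed_out, int total_out) {
  int total = total_in + total_out;
  int transposed = transposed_in + transposed_out;
  if (transposed <= total / 2 || transposed == total) {
    return InsertNone;
  }
  bool need_pre = transposed_in < total_in;     // some input lacks a nchw2nhwc producer
  bool need_post = transposed_out < total_out;  // some output lacks a nhwc2nchw consumer
  if (need_pre && need_post) {
    return BothInsert;
  }
  return need_pre ? PreInsert : PostInsert;
}

int main() {
  // A concat with 3 inputs / 1 output, 2 inputs already transposed:
  // 2 <= 4 / 2 holds, so nothing is inserted.
  printf("%d\n", Decide(2, 3, 0, 1));  // prints 0 (InsertNone)
  // The same concat when its single consumer is already a nhwc2nchw:
  // 3 of 4 edges are transposed, so the remaining input gets a transpose pair.
  printf("%d\n", Decide(2, 3, 1, 1));  // prints 1 (PreInsert)
  return 0;
}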


@@ -45,8 +45,13 @@ class NPUInsertTransformPass : public NPUBasePass {
   int InsertPostNodes(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *trans_kernels);
-  int InsertNode(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel,
+  int InsertNode(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel, size_t post_input_index,
                  std::vector<kernel::LiteKernel *> *trans_kernels);
+  int InsertForInputTensor(kernel::LiteKernel *kernel, size_t in_tensor_index, kernel::LiteKernel *pre_kernel,
+                           std::vector<kernel::LiteKernel *> *trans_kernels);
+  int InsertForOutputTensor(kernel::LiteKernel *kernel, kernel::LiteKernel *post_kernel, size_t post_in_tensor_index,
+                            std::vector<kernel::LiteKernel *> *trans_kernels);
  private:
   int total = 0;


@@ -172,32 +172,33 @@ void NPUPassUtils::UpdateNC2NHPostKernelInTensors(kernel::LiteKernel *kernel, ke
 void NPUPassUtils::UpdateNC2NHTransNodePostKernel(kernel::LiteKernel *kernel, kernel::LiteKernel *trans_kernel,
                                                   kernel::LiteKernel *post_kernel) {
   // For post_kernel after trans, kernel should be replaced with trans_kernel.
-  auto post_in_tensors = post_kernel->in_tensors();
-  if (kernel == nullptr) {
-    post_in_tensors[0] = trans_kernel->out_tensors()[0];
-  } else {
-    for (size_t i = 0; i < post_in_tensors.size(); i++) {
-      if (post_in_tensors[i] == kernel->out_tensors()[0]) {
-        post_in_tensors[i] = trans_kernel->out_tensors()[0];
-        break;
-      }
-    }
-  }
-  post_kernel->set_in_tensors(post_in_tensors);
-  // The input tensor should be replaced with the output tensor of trans_kernel.
-  std::vector<kernel::LiteKernel *> post_in_kernels = post_kernel->in_kernels();
-  for (size_t i = 0; i < post_in_kernels.size(); i++) {
-    if (post_in_kernels[i] == kernel) {
-      post_in_kernels[i] = trans_kernel;
-      break;
-    }
-  }
+  // The input tensor should be replaced with the output tensor of trans_kernel.
+  auto post_in_tensors = post_kernel->in_tensors();
+  Tensor *old_in_tensor = nullptr;
+  // Find out which input tensor of post_kernel should be updated.
+  for (size_t i = 0; i < post_in_tensors.size(); ++i) {
+    if (KernelInputFromKernel(post_kernel, i) == kernel) {
+      old_in_tensor = post_in_tensors.at(i);
+      break;
+    }
+  }
+  if (old_in_tensor == nullptr) {
+    MS_LOG(WARNING) << "Could not find the input tensor index.";
+    return;
+  }
+  std::replace(post_in_tensors.begin(), post_in_tensors.end(), old_in_tensor, trans_kernel->out_tensors().at(0));
+  post_kernel->set_in_tensors(post_in_tensors);
+  // For post_kernel after trans, kernel in in_kernels should be replaced with trans_kernel.
+  auto post_in_kernels = post_kernel->in_kernels();
+  std::replace(post_in_kernels.begin(), post_in_kernels.end(), kernel, trans_kernel);
   post_kernel->set_in_kernels(post_in_kernels);
 }
 bool NPUPassUtils::IsNhwc2Nchw(const kernel::LiteKernel *kernel) {
+  if (kernel == nullptr) {
+    return false;
+  }
   if (kernel->Type() != schema::PrimitiveType_Transpose) {
     return false;
   }
@@ -215,6 +216,9 @@ bool NPUPassUtils::IsNhwc2Nchw(const kernel::LiteKernel *kernel) {
 }
 bool NPUPassUtils::IsNchw2Nhwc(const kernel::LiteKernel *kernel) {
+  if (kernel == nullptr) {
+    return false;
+  }
   if (kernel->Type() != schema::PrimitiveType_Transpose) {
     return false;
   }
@@ -230,5 +234,22 @@ bool NPUPassUtils::IsNchw2Nhwc(const kernel::LiteKernel *kernel) {
   }
   return false;
 }
+kernel::LiteKernel *NPUPassUtils::KernelInputFromKernel(const kernel::LiteKernel *kernel, size_t in_tensor_index) {
+  // Given a kernel and an input tensor index, find the kernel that outputs this tensor.
+  // If the input tensor is a graph input, return nullptr.
+  if (kernel == nullptr) {
+    return nullptr;
+  }
+  auto tensor = kernel->in_tensors().at(in_tensor_index);
+  auto in_kernels = kernel->in_kernels();
+  auto output_contain = [tensor](const kernel::LiteKernel *kernel) {
+    auto out_tensors = kernel->out_tensors();
+    return std::find(out_tensors.begin(), out_tensors.end(), tensor) != out_tensors.end();
+  };
+  auto it = std::find_if(in_kernels.begin(), in_kernels.end(), output_contain);
+  if (it == in_kernels.end()) {
+    return nullptr;
+  }
+  return *it;
+}
 }  // namespace mindspore::lite
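
The helper matters because in_kernels is not parallel to in_tensors: const tensors and graph inputs have no producing kernel, so the producer has to be located by searching out_tensors. A minimal sketch with hypothetical Node/Tensor stand-ins (not the real LiteKernel API) showing the lookup return nullptr for a graph input:

#include <algorithm>
#include <cassert>
#include <vector>

struct Tensor {};
struct Node {
  std::vector<Tensor *> in_tensors;
  std::vector<Tensor *> out_tensors;
  std::vector<Node *> in_nodes;
};

// Same shape as NPUPassUtils::KernelInputFromKernel: scan the producers and
// return the one whose out_tensors contain in_tensors[index], else nullptr.
Node *InputFrom(const Node *node, size_t index) {
  Tensor *tensor = node->in_tensors.at(index);
  auto produces = [tensor](const Node *n) {
    return std::find(n->out_tensors.begin(), n->out_tensors.end(), tensor) != n->out_tensors.end();
  };
  auto it = std::find_if(node->in_nodes.begin(), node->in_nodes.end(), produces);
  return it == node->in_nodes.end() ? nullptr : *it;
}

int main() {
  Tensor graph_input, conv_out;
  Node conv{{}, {&conv_out}, {}};
  Node add{{&graph_input, &conv_out}, {}, {&conv}};  // input 0 is a graph input
  assert(InputFrom(&add, 0) == nullptr);             // no producer: graph input
  assert(InputFrom(&add, 1) == &conv);               // produced by conv
  return 0;
}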


@@ -52,6 +52,7 @@ class NPUPassUtils {
   static bool IsNhwc2Nchw(const kernel::LiteKernel *kernel);
   static bool IsNchw2Nhwc(const kernel::LiteKernel *kernel);
+  static kernel::LiteKernel *KernelInputFromKernel(const kernel::LiteKernel *kernel, size_t in_tensor_index);
  private:
   static PrimitiveC *CreateTransposePrimitive();


@@ -26,7 +26,7 @@ crnn_lite_lstm_v2.onnx;32,32,32,1 0.3
 psenet_lite_mbv2.onnx;1,32,32,3 0.6
 super-resolution-10.onnx;1,224,224,1 4.5
 tinyyolov2-8.onnx;1,416,416,3 5.5
-ml_2012_ocr_cn.onnx -1
+#ml_2012_ocr_cn.onnx -1
 #ml_2012_ocr_cn_noLSTM.onnx 1
 candy-9.onnx 5
 mosaic-9.onnx 4