!19229 mindrt bug : input subgraph

Merge pull request !19229 from ling/bug
This commit is contained in:
i-robot 2021-07-01 11:10:11 +00:00 committed by Gitee
commit 53ef22a2e2
3 changed files with 42 additions and 13 deletions

View File

@ -73,19 +73,19 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_);
auto ret = InferSubGraphShape(kMainSubGraphIndex);
if (ret != RET_OK) {
int infershape_ret = InferSubGraphShape(kMainSubGraphIndex);
if (infershape_ret != RET_OK && infershape_ret != RET_INFER_INVALID) {
MS_LOG(ERROR) << "op infer shape failed.";
return ret;
return infershape_ret;
}
if (context_->enable_parallel_) {
if (context_->enable_parallel_ && infershape_ret != RET_INFER_INVALID) {
auto search_sub_graph =
SearchSubGraph(context_, src_model_, src_tensors_, &op_parameters_, &graph_output_node_indexes_);
search_sub_graph.SubGraphSplit();
}
ret = ScheduleSubGraphToKernels(kMainSubGraphIndex, dst_kernels, nullptr, nullptr);
int ret = ScheduleSubGraphToKernels(kMainSubGraphIndex, dst_kernels, nullptr, nullptr);
op_parameters_.clear();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Schedule main subgraph to kernels failed.";
@ -247,6 +247,7 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index) {
MS_ASSERT(!src_model_->sub_graphs_.empty());
MS_ASSERT(src_model_->sub_graphs_.size() > subgraph_index);
auto subgraph = src_model_->sub_graphs_.at(subgraph_index);
int subgraph_infershape_ret = RET_OK;
for (auto node_index : subgraph->node_indices_) {
auto node = src_model_->all_nodes_[node_index];
MS_ASSERT(node != nullptr);
@ -260,12 +261,13 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index) {
if (ret == RET_INFER_INVALID) {
MS_LOG(INFO) << "InferShape interrupted, name: " << node->name_ << ", type: " << PrimitiveTypeName(type)
<< ", set infer flag to false.";
subgraph_infershape_ret = RET_INFER_INVALID;
} else if (ret != RET_OK) {
MS_LOG(ERROR) << "InferShape failed, name: " << node->name_ << ", type: " << PrimitiveTypeName(type);
return RET_INFER_ERR;
}
}
return RET_OK;
return subgraph_infershape_ret;
}
namespace {

View File

@ -494,13 +494,18 @@ void SearchSubGraph::InsertNodeByMid(uint32_t node_index, Subgraph *subgraph) {
return;
}
subgraph->nodes_.push_back(node_index);
/* include this multy-in-unit in current subgraph */
std::vector<Subgraph> &subs = subs_iter->second;
std::set<uint32_t> subs_head;
/* insert nodes */
subgraph->nodes_.push_back(node_index);
for (Subgraph &sub : subs) {
subgraph->nodes_.insert(subgraph->nodes_.end(), sub.nodes_.begin(), sub.nodes_.end());
}
/* insert heads */
std::set<uint32_t> subs_head;
for (Subgraph &sub : subs) {
for (uint32_t head : sub.heads_) {
subs_head.insert(head);
}
@ -570,13 +575,17 @@ void SearchSubGraph::InitMiddleSubgraph(std::vector<uint32_t> *multy_in_nodes) {
std::vector<uint32_t> input_nodes = tensor->out_nodes_;
if (input_nodes.empty()) continue;
if (input_nodes.size() != 1) continue;
uint32_t input_node = input_nodes[0];
Subgraph sub;
sub.ends_.push_back(input_nodes[0]);
InsertNodeByMid(input_nodes[0], &sub);
sub.ends_.push_back(input_node);
InsertNodeByMid(input_node, &sub);
node_subs.push_back(sub);
}
node_sub_map_.insert(std::make_pair(node_index, node_subs));
if (!node_subs.empty()) {
node_sub_map_.insert(std::make_pair(node_index, node_subs));
}
}
return;
}
@ -712,6 +721,7 @@ void SearchSubGraph::SubgraphFusion(std::vector<Subgraph> *sub_graphs) {
new_sub.device_ = sub_graphs->at(sub1_index).device_;
new_sub.thread_ = sub_graphs->at(sub1_index).thread_;
new_sub.tid_ = sub_graphs->at(sub1_index).tid_;
new_sub.cost_ = sub_graphs->at(sub1_index).cost_ + sub_graphs->at(sub2_index).cost_;
Subgraph &sub1 = sub_graphs->at(sub1_index);
Subgraph &sub2 = sub_graphs->at(sub2_index);
@ -751,6 +761,10 @@ void SearchSubGraph::CalculateCostModel(std::vector<Subgraph> *sub_graphs) {
}
void SearchSubGraph::SubGraphSplitByOutput() {
if (output_nodes_->size() < kDefalutSubGraphSize) {
return;
}
InitSearchSubGraphByOutput();
CalculateCostModel(&sub_graphs_);
InitSubgraphRuntimeInfo(&sub_graphs_);
@ -758,6 +772,12 @@ void SearchSubGraph::SubGraphSplitByOutput() {
for (Subgraph &sub : sub_graphs_) {
CheckSubHeadEnd(&sub);
}
if (sub_graphs_.at(kDefalutFirstSubgraph).cost_.cost() < kMinSubgraphCost ||
sub_graphs_.at(kDefalutSecondSubgraph).cost_.cost() < kMinSubgraphCost) {
return;
}
ConvertSubGraphToModel(&sub_graphs_);
}
@ -782,7 +802,8 @@ void SearchSubGraph::SubGraphSplitByMiddle() {
/* redo cost-model and pre-set-info after optimize */
CalculateCostModel(&subgraphs);
if (subgraphs.at(0).cost_.cost() == 0 || subgraphs.at(1).cost_.cost() == 0) {
if (subgraphs.at(kDefalutFirstSubgraph).cost_.cost() < kMinSubgraphCost ||
subgraphs.at(kDefalutSecondSubgraph).cost_.cost() < kMinSubgraphCost) {
continue;
}
@ -982,6 +1003,9 @@ bool SearchSubGraph::ValidInParallel() {
if (front_node->quant_type_ != schema::QuantType_QUANT_NONE) {
return false;
}
if (major_dt_ == DT_NPU) {
return false;
}
return true;
}

View File

@ -32,8 +32,11 @@
namespace mindspore::lite {
constexpr int kDefaultDeviceType = -1;
constexpr int kDefalutSubGraphSize = 2;
constexpr int kDefalutFirstSubgraph = 0;
constexpr int kDefalutSecondSubgraph = 1;
constexpr int kDefaultInputs = 1;
constexpr int kMaxSubGraphCount = 20;
constexpr int kMinSubgraphCost = 50;
class SearchSubGraph {
enum TensorType { NORMAL, CONST, INPUT };