!19229 mindrt bug : input subgraph
Merge pull request !19229 from ling/bug
This commit is contained in:
commit
53ef22a2e2
|
@ -73,19 +73,19 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
|
|||
|
||||
this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_);
|
||||
|
||||
auto ret = InferSubGraphShape(kMainSubGraphIndex);
|
||||
if (ret != RET_OK) {
|
||||
int infershape_ret = InferSubGraphShape(kMainSubGraphIndex);
|
||||
if (infershape_ret != RET_OK && infershape_ret != RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "op infer shape failed.";
|
||||
return ret;
|
||||
return infershape_ret;
|
||||
}
|
||||
|
||||
if (context_->enable_parallel_) {
|
||||
if (context_->enable_parallel_ && infershape_ret != RET_INFER_INVALID) {
|
||||
auto search_sub_graph =
|
||||
SearchSubGraph(context_, src_model_, src_tensors_, &op_parameters_, &graph_output_node_indexes_);
|
||||
search_sub_graph.SubGraphSplit();
|
||||
}
|
||||
|
||||
ret = ScheduleSubGraphToKernels(kMainSubGraphIndex, dst_kernels, nullptr, nullptr);
|
||||
int ret = ScheduleSubGraphToKernels(kMainSubGraphIndex, dst_kernels, nullptr, nullptr);
|
||||
op_parameters_.clear();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Schedule main subgraph to kernels failed.";
|
||||
|
@ -247,6 +247,7 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index) {
|
|||
MS_ASSERT(!src_model_->sub_graphs_.empty());
|
||||
MS_ASSERT(src_model_->sub_graphs_.size() > subgraph_index);
|
||||
auto subgraph = src_model_->sub_graphs_.at(subgraph_index);
|
||||
int subgraph_infershape_ret = RET_OK;
|
||||
for (auto node_index : subgraph->node_indices_) {
|
||||
auto node = src_model_->all_nodes_[node_index];
|
||||
MS_ASSERT(node != nullptr);
|
||||
|
@ -260,12 +261,13 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index) {
|
|||
if (ret == RET_INFER_INVALID) {
|
||||
MS_LOG(INFO) << "InferShape interrupted, name: " << node->name_ << ", type: " << PrimitiveTypeName(type)
|
||||
<< ", set infer flag to false.";
|
||||
subgraph_infershape_ret = RET_INFER_INVALID;
|
||||
} else if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InferShape failed, name: " << node->name_ << ", type: " << PrimitiveTypeName(type);
|
||||
return RET_INFER_ERR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
return subgraph_infershape_ret;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -494,13 +494,18 @@ void SearchSubGraph::InsertNodeByMid(uint32_t node_index, Subgraph *subgraph) {
|
|||
return;
|
||||
}
|
||||
|
||||
subgraph->nodes_.push_back(node_index);
|
||||
|
||||
/* include this multy-in-unit in current subgraph */
|
||||
std::vector<Subgraph> &subs = subs_iter->second;
|
||||
std::set<uint32_t> subs_head;
|
||||
|
||||
/* insert nodes */
|
||||
subgraph->nodes_.push_back(node_index);
|
||||
for (Subgraph &sub : subs) {
|
||||
subgraph->nodes_.insert(subgraph->nodes_.end(), sub.nodes_.begin(), sub.nodes_.end());
|
||||
}
|
||||
|
||||
/* insert heads */
|
||||
std::set<uint32_t> subs_head;
|
||||
for (Subgraph &sub : subs) {
|
||||
for (uint32_t head : sub.heads_) {
|
||||
subs_head.insert(head);
|
||||
}
|
||||
|
@ -570,13 +575,17 @@ void SearchSubGraph::InitMiddleSubgraph(std::vector<uint32_t> *multy_in_nodes) {
|
|||
|
||||
std::vector<uint32_t> input_nodes = tensor->out_nodes_;
|
||||
if (input_nodes.empty()) continue;
|
||||
if (input_nodes.size() != 1) continue;
|
||||
uint32_t input_node = input_nodes[0];
|
||||
|
||||
Subgraph sub;
|
||||
sub.ends_.push_back(input_nodes[0]);
|
||||
InsertNodeByMid(input_nodes[0], &sub);
|
||||
sub.ends_.push_back(input_node);
|
||||
InsertNodeByMid(input_node, &sub);
|
||||
node_subs.push_back(sub);
|
||||
}
|
||||
node_sub_map_.insert(std::make_pair(node_index, node_subs));
|
||||
if (!node_subs.empty()) {
|
||||
node_sub_map_.insert(std::make_pair(node_index, node_subs));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@ -712,6 +721,7 @@ void SearchSubGraph::SubgraphFusion(std::vector<Subgraph> *sub_graphs) {
|
|||
new_sub.device_ = sub_graphs->at(sub1_index).device_;
|
||||
new_sub.thread_ = sub_graphs->at(sub1_index).thread_;
|
||||
new_sub.tid_ = sub_graphs->at(sub1_index).tid_;
|
||||
new_sub.cost_ = sub_graphs->at(sub1_index).cost_ + sub_graphs->at(sub2_index).cost_;
|
||||
|
||||
Subgraph &sub1 = sub_graphs->at(sub1_index);
|
||||
Subgraph &sub2 = sub_graphs->at(sub2_index);
|
||||
|
@ -751,6 +761,10 @@ void SearchSubGraph::CalculateCostModel(std::vector<Subgraph> *sub_graphs) {
|
|||
}
|
||||
|
||||
void SearchSubGraph::SubGraphSplitByOutput() {
|
||||
if (output_nodes_->size() < kDefalutSubGraphSize) {
|
||||
return;
|
||||
}
|
||||
|
||||
InitSearchSubGraphByOutput();
|
||||
CalculateCostModel(&sub_graphs_);
|
||||
InitSubgraphRuntimeInfo(&sub_graphs_);
|
||||
|
@ -758,6 +772,12 @@ void SearchSubGraph::SubGraphSplitByOutput() {
|
|||
for (Subgraph &sub : sub_graphs_) {
|
||||
CheckSubHeadEnd(&sub);
|
||||
}
|
||||
|
||||
if (sub_graphs_.at(kDefalutFirstSubgraph).cost_.cost() < kMinSubgraphCost ||
|
||||
sub_graphs_.at(kDefalutSecondSubgraph).cost_.cost() < kMinSubgraphCost) {
|
||||
return;
|
||||
}
|
||||
|
||||
ConvertSubGraphToModel(&sub_graphs_);
|
||||
}
|
||||
|
||||
|
@ -782,7 +802,8 @@ void SearchSubGraph::SubGraphSplitByMiddle() {
|
|||
|
||||
/* redo cost-model and pre-set-info after optimize */
|
||||
CalculateCostModel(&subgraphs);
|
||||
if (subgraphs.at(0).cost_.cost() == 0 || subgraphs.at(1).cost_.cost() == 0) {
|
||||
if (subgraphs.at(kDefalutFirstSubgraph).cost_.cost() < kMinSubgraphCost ||
|
||||
subgraphs.at(kDefalutSecondSubgraph).cost_.cost() < kMinSubgraphCost) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -982,6 +1003,9 @@ bool SearchSubGraph::ValidInParallel() {
|
|||
if (front_node->quant_type_ != schema::QuantType_QUANT_NONE) {
|
||||
return false;
|
||||
}
|
||||
if (major_dt_ == DT_NPU) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -32,8 +32,11 @@
|
|||
namespace mindspore::lite {
|
||||
// Constants used by SearchSubGraph when splitting a model into sub-graphs.

// NOTE(review): usage is not visible in this view — presumably a sentinel for
// "no device type chosen"; confirm against the code that reads it.
constexpr int kDefaultDeviceType = -1;
|
||||
// SubGraphSplitByOutput() returns early when the number of graph output nodes
// is smaller than this value.
constexpr int kDefalutSubGraphSize = 2;
|
||||
// Index of the first of the two sub-graphs whose cost is compared against
// kMinSubgraphCost.
constexpr int kDefalutFirstSubgraph = 0;
|
||||
// Index of the second of the two sub-graphs whose cost is compared against
// kMinSubgraphCost.
constexpr int kDefalutSecondSubgraph = 1;
|
||||
// NOTE(review): usage is not visible in this view — meaning to be confirmed.
constexpr int kDefaultInputs = 1;
|
||||
// NOTE(review): usage is not visible in this view — presumably an upper bound
// on the number of sub-graphs; confirm against the code that reads it.
constexpr int kMaxSubGraphCount = 20;
|
||||
// Cost threshold for a split to be applied: if cost_.cost() of either the
// first or the second sub-graph is below this value, SubGraphSplitByOutput()
// returns without converting the sub-graphs to a model, and
// SubGraphSplitByMiddle() skips the current candidate.
constexpr int kMinSubgraphCost = 50;
|
||||
class SearchSubGraph {
|
||||
enum TensorType { NORMAL, CONST, INPUT };
|
||||
|
||||
|
|
Loading…
Reference in New Issue