commit
494639ad8e
|
@ -25,7 +25,10 @@
|
|||
#include "schema/model_generated.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "nnacl/fp32/winograd_utils.h"
|
||||
#include "nnacl/pooling_parameter.h"
|
||||
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
|
||||
#include "nnacl/fp32/conv_depthwise_fp32.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore::lite {
|
||||
|
||||
|
@ -52,7 +55,7 @@ size_t WinogradConvDwMul() {
|
|||
}
|
||||
|
||||
void SearchSubGraph::dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group,
|
||||
std::vector<bool> *cor_group) {
|
||||
std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs) {
|
||||
if (i == n) {
|
||||
if (abs(except_value - current_sum) < *min_value) {
|
||||
for (int i = 0; i < n; i++) {
|
||||
|
@ -65,12 +68,13 @@ void SearchSubGraph::dfs(int i, int n, int current_sum, int except_value, int *m
|
|||
|
||||
{
|
||||
tmp_group->at(i) = true;
|
||||
dfs(i + 1, n, current_sum + sub_graphs_[i].cost_.cost(), except_value, min_value, tmp_group, cor_group);
|
||||
int next_sum = current_sum + sub_graphs->at(i).cost_.cost();
|
||||
dfs(i + 1, n, next_sum, except_value, min_value, tmp_group, cor_group, sub_graphs);
|
||||
}
|
||||
|
||||
{
|
||||
tmp_group->at(i) = false;
|
||||
dfs(i + 1, n, current_sum, except_value, min_value, tmp_group, cor_group);
|
||||
dfs(i + 1, n, current_sum, except_value, min_value, tmp_group, cor_group, sub_graphs);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@ -141,10 +145,13 @@ const schema::Primitive *SearchSubGraph::CreatePartialPrimitive(int64_t subgraph
|
|||
return std::move(primitive);
|
||||
}
|
||||
|
||||
void SearchSubGraph::ConvertSubGraphToModel() {
|
||||
void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
|
||||
if (sub_graphs->size() != 2) {
|
||||
return;
|
||||
}
|
||||
Model::SubGraph *main_graphs = model_->sub_graphs_.front();
|
||||
|
||||
for (Subgraph &subgraph : sub_graphs_) {
|
||||
for (Subgraph &subgraph : *sub_graphs) {
|
||||
if (subgraph.nodes_.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
@ -209,7 +216,7 @@ void SearchSubGraph::ConvertSubGraphToModel() {
|
|||
model_->sub_graphs_.push_back(std::move(new_sub_graph));
|
||||
}
|
||||
|
||||
sub_graphs_.clear();
|
||||
sub_graphs->clear();
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -220,6 +227,9 @@ bool SearchSubGraph::IsNodeSubGraphHead(uint32_t node_index, const std::vector<u
|
|||
std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_;
|
||||
output_nodes.insert(output_nodes.end(), cur_nodes.begin(), cur_nodes.end());
|
||||
}
|
||||
if (output_indexes.size() == 1 && output_nodes.size() == 1) {
|
||||
return false;
|
||||
}
|
||||
for (uint32_t out_n : output_nodes) {
|
||||
if (find(ready_nodes.begin(), ready_nodes.end(), out_n) == ready_nodes.end()) {
|
||||
return true;
|
||||
|
@ -228,6 +238,54 @@ bool SearchSubGraph::IsNodeSubGraphHead(uint32_t node_index, const std::vector<u
|
|||
return false;
|
||||
}
|
||||
|
||||
bool SearchSubGraph::IsNodeSubGraphHeadWithRoot(uint32_t node_index, const std::vector<uint32_t> &ready_nodes,
|
||||
uint32_t root_node_index) {
|
||||
std::vector<uint32_t> output_indexes = node_list_.at(node_index)->output_indices_;
|
||||
std::vector<uint32_t> output_nodes;
|
||||
for (uint32_t out_t : output_indexes) {
|
||||
std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_;
|
||||
output_nodes.insert(output_nodes.end(), cur_nodes.begin(), cur_nodes.end());
|
||||
}
|
||||
for (uint32_t out_n : output_nodes) {
|
||||
if (root_node_index != out_n) {
|
||||
if (find(ready_nodes.begin(), ready_nodes.end(), out_n) == ready_nodes.end()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void SearchSubGraph::SearchMultyInNodes(std::vector<uint32_t> *multy_in_nodes) {
|
||||
std::vector<uint32_t> all_main_sub_nodes = model_->sub_graphs_[0]->node_indices_;
|
||||
|
||||
for (size_t i = 0; i < all_main_sub_nodes.size(); i++) {
|
||||
uint32_t node_index = all_main_sub_nodes[i];
|
||||
Model::Node *node = node_list_[node_index];
|
||||
|
||||
if (IsPartialNode(node->primitive_)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int input_count = std::count_if(node->input_indices_.begin(), node->input_indices_.end(),
|
||||
[&](uint32_t in_tensor_index) { return tensors_[in_tensor_index].type_ != CONST; });
|
||||
|
||||
if (input_count > 1) {
|
||||
multy_in_nodes->push_back(node_index);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void SearchSubGraph::RemoveConstNode(std::vector<uint32_t> *nodes) {
|
||||
for (int i = nodes->size() - 1; i >= 0; i--) {
|
||||
if (tensors_[nodes->at(i)].type_ == CONST) {
|
||||
VectorErase(nodes, nodes->at(i));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void SearchSubGraph::InsertNode(uint32_t index, Subgraph *subgraph) {
|
||||
if (subgraph->search_terminate_) {
|
||||
return;
|
||||
|
@ -239,12 +297,7 @@ void SearchSubGraph::InsertNode(uint32_t index, Subgraph *subgraph) {
|
|||
}
|
||||
|
||||
std::vector<uint32_t> input = node->input_indices_;
|
||||
/* remove const node */
|
||||
for (int i = input.size() - 1; i >= 0; i--) {
|
||||
if (tensors_[input[i]].type_ == CONST) {
|
||||
VectorErase(&input, input[i]);
|
||||
}
|
||||
}
|
||||
RemoveConstNode(&input);
|
||||
|
||||
/* all node_input is graph_input */
|
||||
for (size_t i = 0; i < input.size(); i++) {
|
||||
|
@ -286,8 +339,204 @@ void SearchSubGraph::InsertNode(uint32_t index, Subgraph *subgraph) {
|
|||
return;
|
||||
}
|
||||
|
||||
void SearchSubGraph::InitSearchSubGraphByOutput() {
|
||||
void SearchSubGraph::OptimizeAfterFusion(std::vector<Subgraph> *sub_graphs, uint32_t root_node_index) {
|
||||
MS_ASSERT(sub_graphs->size() == 2);
|
||||
for (Subgraph &sub : *sub_graphs) {
|
||||
if (sub.nodes_.empty()) {
|
||||
return;
|
||||
}
|
||||
int head_size = sub.heads_.size();
|
||||
std::vector<uint32_t> used_heads;
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
uint32_t head_node_index = sub.heads_.at(i);
|
||||
if (std::find(used_heads.begin(), used_heads.end(), head_node_index) != used_heads.end()) {
|
||||
break;
|
||||
}
|
||||
std::vector<uint32_t> head_input_tensors = model_->all_nodes_[head_node_index]->input_indices_;
|
||||
RemoveConstNode(&head_input_tensors);
|
||||
if (head_input_tensors.size() != 1) continue;
|
||||
|
||||
std::vector<uint32_t> input_nodes = tensors_.at(head_input_tensors.at(0)).out_nodes_;
|
||||
if (input_nodes.size() != 1) continue;
|
||||
uint32_t input_node_index = input_nodes.at(0);
|
||||
|
||||
std::vector<uint32_t> input_tensors = model_->all_nodes_[input_node_index]->input_indices_;
|
||||
RemoveConstNode(&input_tensors);
|
||||
if (input_tensors.size() != 1) continue;
|
||||
|
||||
/* this node qualified:
|
||||
* 1. the only input node of current head node
|
||||
* 2. all output included in current subgraph
|
||||
* 3. one input-tensor */
|
||||
if (!IsNodeSubGraphHeadWithRoot(input_node_index, sub.nodes_, root_node_index)) {
|
||||
InsertHeadNode(input_node_index, &sub);
|
||||
used_heads.push_back(head_node_index); /* delete used head at end */
|
||||
}
|
||||
head_size = sub.heads_.size();
|
||||
}
|
||||
for (auto head_index : used_heads) {
|
||||
VectorErase(&sub.heads_, head_index);
|
||||
}
|
||||
|
||||
/* double check head-end node */
|
||||
/* head-end node may error after subgraph fusion */
|
||||
for (uint32_t head_node : sub.heads_) {
|
||||
if (std::find(sub.nodes_.begin(), sub.nodes_.end(), head_node) == sub.nodes_.end()) {
|
||||
VectorErase(&sub.nodes_, head_node);
|
||||
}
|
||||
}
|
||||
for (uint32_t end_node : sub.ends_) {
|
||||
if (std::find(sub.nodes_.begin(), sub.nodes_.end(), end_node) == sub.nodes_.end()) {
|
||||
VectorErase(&sub.ends_, end_node);
|
||||
}
|
||||
}
|
||||
|
||||
/* sort node index */
|
||||
std::sort(sub.nodes_.begin(), sub.nodes_.end());
|
||||
}
|
||||
}
|
||||
|
||||
void SearchSubGraph::InsertHeadNode(uint32_t head_node_index, Subgraph *subgraph) {
|
||||
Model::Node *node = node_list_.at(head_node_index);
|
||||
std::vector<uint32_t> head_node_inputs = node->input_indices_;
|
||||
RemoveConstNode(&head_node_inputs);
|
||||
|
||||
subgraph->nodes_.push_back(head_node_index);
|
||||
node_list_.at(head_node_index) = nullptr;
|
||||
|
||||
/* search for next node */
|
||||
size_t current_node_size = subgraph->nodes_.size();
|
||||
for (uint32_t in : head_node_inputs) {
|
||||
auto next_nodes = tensors_[in].out_nodes_;
|
||||
for (uint32_t next_node : next_nodes) {
|
||||
InsertNodeByMid(next_node, subgraph);
|
||||
}
|
||||
}
|
||||
|
||||
if (current_node_size == subgraph->nodes_.size()) {
|
||||
subgraph->heads_.push_back(head_node_index);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void SearchSubGraph::InsertNodeByMid(uint32_t node_index, Subgraph *subgraph) {
|
||||
Model::Node *node = node_list_.at(node_index);
|
||||
if (node == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto subs_iter = node_sub_map_.find(node_index);
|
||||
if (subs_iter != node_sub_map_.end()) {
|
||||
/* node is multy-in node , already searched before */
|
||||
|
||||
if (IsNodeSubGraphHead(node_index, subgraph->nodes_)) {
|
||||
/* this node can not be included in this subgraph */
|
||||
if (!subgraph->nodes_.empty()) subgraph->heads_.push_back(subgraph->nodes_.front());
|
||||
return;
|
||||
}
|
||||
|
||||
subgraph->nodes_.push_back(node_index);
|
||||
|
||||
/* include this multy-in-unit in current subgraph */
|
||||
std::vector<Subgraph> &subs = subs_iter->second;
|
||||
std::set<uint32_t> subs_head;
|
||||
for (Subgraph &sub : subs) {
|
||||
subgraph->nodes_.insert(subgraph->nodes_.end(), sub.nodes_.begin(), sub.nodes_.end());
|
||||
for (uint32_t head : sub.heads_) {
|
||||
subs_head.insert(head);
|
||||
}
|
||||
}
|
||||
|
||||
std::set<uint32_t> subs_head_baklist = subs_head;
|
||||
for (uint32_t head_node : subs_head) {
|
||||
std::vector<uint32_t> head_input_tensors = model_->all_nodes_[head_node]->input_indices_;
|
||||
RemoveConstNode(&head_input_tensors);
|
||||
if (head_input_tensors.size() != 1) continue;
|
||||
std::vector<uint32_t> input_nodes = tensors_.at(head_input_tensors.at(0)).out_nodes_;
|
||||
if (input_nodes.size() != 1) continue;
|
||||
|
||||
uint32_t input_node = input_nodes.at(0);
|
||||
if (!IsNodeSubGraphHead(input_node, subgraph->nodes_)) {
|
||||
InsertNodeByMid(input_node, subgraph);
|
||||
subs_head_baklist.erase(head_node);
|
||||
}
|
||||
}
|
||||
|
||||
/* stop search */
|
||||
for (auto head : subs_head_baklist) {
|
||||
subgraph->heads_.push_back(head);
|
||||
}
|
||||
node_sub_map_.erase(node_index);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> inputs = node->input_indices_;
|
||||
RemoveConstNode(&inputs);
|
||||
|
||||
if (IsNodeSubGraphHead(node_index, subgraph->nodes_)) {
|
||||
if (!subgraph->nodes_.empty()) {
|
||||
uint32_t current_node_list_head = subgraph->nodes_.front();
|
||||
if (std::find(subgraph->heads_.begin(), subgraph->heads_.end(), current_node_list_head) ==
|
||||
subgraph->heads_.end()) {
|
||||
subgraph->heads_.push_back(current_node_list_head);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
subgraph->nodes_.insert(subgraph->nodes_.begin(), node_index);
|
||||
node_list_.at(node_index) = nullptr;
|
||||
|
||||
/* search for next node */
|
||||
for (uint32_t in : inputs) {
|
||||
auto next_nodes = tensors_[in].out_nodes_;
|
||||
if (next_nodes.size() == 0) {
|
||||
if (!subgraph->nodes_.empty()) subgraph->heads_.push_back(subgraph->nodes_.front());
|
||||
} else {
|
||||
for (uint32_t next_node : next_nodes) {
|
||||
InsertNodeByMid(next_node, subgraph);
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void SearchSubGraph::InitMiddleSubgraph(std::vector<uint32_t> *multy_in_nodes) {
|
||||
for (uint32_t node_index : *multy_in_nodes) {
|
||||
std::vector<Subgraph> node_subs;
|
||||
Model::Node *node = node_list_[node_index];
|
||||
for (uint32_t input_tensor_index : node->input_indices_) {
|
||||
Tensor *tensor = &tensors_[input_tensor_index];
|
||||
if (tensor->type_ == CONST) continue;
|
||||
|
||||
std::vector<uint32_t> input_nodes = tensor->out_nodes_;
|
||||
Subgraph sub;
|
||||
sub.ends_.push_back(input_nodes[0]);
|
||||
InsertNodeByMid(input_nodes[0], &sub);
|
||||
node_subs.push_back(sub);
|
||||
}
|
||||
node_sub_map_.insert(std::make_pair(node_index, node_subs));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
/* Prepare the split-by-middle search: reset sub_graphs_, reload node_list_ from the
 * model, find the multi-input nodes of the main graph and build their per-input
 * subgraphs (stored by InitMiddleSubgraph). */
void SearchSubGraph::InitSearchSubGraphByMiddle() {
  /* start from a clean state */
  sub_graphs_.clear();
  node_list_ = model_->all_nodes_;

  /* step 1: locate nodes with more than one non-const input */
  std::vector<uint32_t> multi_input_nodes;
  SearchMultyInNodes(&multi_input_nodes);

  /* step 2: grow one subgraph per input of each such node */
  InitMiddleSubgraph(&multi_input_nodes);
}
|
||||
|
||||
void SearchSubGraph::InitSearchSubGraphByOutput() {
|
||||
sub_graphs_.clear();
|
||||
node_list_ = model_->all_nodes_;
|
||||
|
||||
for (uint32_t out : *output_nodes_) {
|
||||
Subgraph subgraph;
|
||||
|
||||
|
@ -298,8 +547,6 @@ void SearchSubGraph::InitSearchSubGraphByOutput() {
|
|||
return;
|
||||
}
|
||||
|
||||
void SearchSubGraph::InitSearchSubGraphByMiddle() { return; }
|
||||
|
||||
void SearchSubGraph::InitSearchTensor() {
|
||||
tensors_.resize(model_->all_tensors_.size());
|
||||
|
||||
|
@ -333,23 +580,23 @@ void SearchSubGraph::InitSearchTensor() {
|
|||
return;
|
||||
}
|
||||
|
||||
void SearchSubGraph::InitSubgraphRuntimeInfo() {
|
||||
void SearchSubGraph::InitSubgraphRuntimeInfo(std::vector<Subgraph> *sub_graphs) {
|
||||
std::vector<bool> tmp_group;
|
||||
std::vector<bool> cor_group;
|
||||
|
||||
tmp_group.resize(sub_graphs_.size());
|
||||
cor_group.resize(sub_graphs_.size());
|
||||
tmp_group.resize(sub_graphs->size());
|
||||
cor_group.resize(sub_graphs->size());
|
||||
|
||||
int except_value = total_cost_ * 0.5; /* major device responsible for 50% calculation */
|
||||
int min_value = INT32_MAX;
|
||||
|
||||
dfs(0, sub_graphs_.size(), 0, except_value, &min_value, &tmp_group, &cor_group);
|
||||
dfs(0, sub_graphs->size(), 0, except_value, &min_value, &tmp_group, &cor_group, sub_graphs);
|
||||
|
||||
/* make bigger half using major_dt_*/
|
||||
int true_value = 0;
|
||||
for (size_t i = 0; i < sub_graphs_.size(); i++) {
|
||||
for (size_t i = 0; i < sub_graphs->size(); i++) {
|
||||
if (cor_group.at(i)) {
|
||||
true_value += sub_graphs_[i].cost_.cost();
|
||||
true_value += sub_graphs->at(i).cost_.cost();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -357,34 +604,35 @@ void SearchSubGraph::InitSubgraphRuntimeInfo() {
|
|||
(void)std::transform(cor_group.begin(), cor_group.end(), cor_group.begin(), [](bool value) { return !value; });
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < sub_graphs_.size(); i++) {
|
||||
for (size_t i = 0; i < sub_graphs->size(); i++) {
|
||||
if (cor_group.at(i)) {
|
||||
sub_graphs_[i].device_ = major_dt_;
|
||||
sub_graphs_[i].thread_ = major_thread_;
|
||||
sub_graphs->at(i).device_ = major_dt_;
|
||||
sub_graphs->at(i).thread_ = major_thread_;
|
||||
sub_graphs->at(i).tid_ = 0;
|
||||
} else {
|
||||
sub_graphs_[i].device_ = minor_dt_;
|
||||
sub_graphs_[i].thread_ = minor_thread_;
|
||||
sub_graphs->at(i).device_ = minor_dt_;
|
||||
sub_graphs->at(i).thread_ = minor_thread_;
|
||||
sub_graphs->at(i).tid_ = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SearchSubGraph::InitMainGraphDevice() {
|
||||
void SearchSubGraph::InitMainGraphDevice(DeviceType dt) {
|
||||
Model::SubGraph *main_graph = model_->sub_graphs_.front();
|
||||
for (uint32_t node_index : main_graph->node_indices_) {
|
||||
Model::Node *node = model_->all_nodes_[node_index];
|
||||
node->device_type_ = major_dt_;
|
||||
node->device_type_ = dt;
|
||||
}
|
||||
}
|
||||
|
||||
void SearchSubGraph::SubgraphFusion() {
|
||||
while (sub_graphs_.size() > 2) {
|
||||
void SearchSubGraph::SubgraphFusion(std::vector<Subgraph> *sub_graphs) {
|
||||
while (sub_graphs->size() > 2) {
|
||||
size_t sub1_index = 0;
|
||||
size_t sub2_index = 0;
|
||||
bool is_found = false;
|
||||
for (sub1_index = 0; sub1_index < sub_graphs_.size(); sub1_index++) {
|
||||
for (size_t tmp2 = sub1_index + 1; tmp2 < sub_graphs_.size(); tmp2++) {
|
||||
if (sub_graphs_[sub1_index].device_ == sub_graphs_[tmp2].device_ &&
|
||||
sub_graphs_[sub1_index].thread_ == sub_graphs_[tmp2].thread_) {
|
||||
for (sub1_index = 0; sub1_index < sub_graphs->size(); sub1_index++) {
|
||||
for (size_t tmp2 = sub1_index + 1; tmp2 < sub_graphs->size(); tmp2++) {
|
||||
if (sub_graphs->at(sub1_index).tid_ == sub_graphs->at(tmp2).tid_) {
|
||||
sub2_index = tmp2;
|
||||
is_found = true;
|
||||
break;
|
||||
|
@ -397,54 +645,83 @@ void SearchSubGraph::SubgraphFusion() {
|
|||
MS_ASSERT(sub2_index > sub1_index); /* erase sub2 then sub1 */
|
||||
|
||||
Subgraph new_sub;
|
||||
new_sub.device_ = sub_graphs_[sub1_index].device_;
|
||||
new_sub.thread_ = sub_graphs_[sub1_index].thread_;
|
||||
new_sub.device_ = sub_graphs->at(sub1_index).device_;
|
||||
new_sub.thread_ = sub_graphs->at(sub1_index).thread_;
|
||||
new_sub.tid_ = sub_graphs->at(sub1_index).tid_;
|
||||
|
||||
Subgraph &sub1 = sub_graphs_[sub1_index];
|
||||
Subgraph &sub2 = sub_graphs_[sub2_index];
|
||||
Subgraph &sub1 = sub_graphs->at(sub1_index);
|
||||
Subgraph &sub2 = sub_graphs->at(sub2_index);
|
||||
new_sub.nodes_.insert(new_sub.nodes_.end(), sub1.nodes_.begin(), sub1.nodes_.end());
|
||||
new_sub.nodes_.insert(new_sub.nodes_.end(), sub2.nodes_.begin(), sub2.nodes_.end());
|
||||
new_sub.heads_.insert(new_sub.heads_.end(), sub1.heads_.begin(), sub1.heads_.end());
|
||||
new_sub.heads_.insert(new_sub.heads_.end(), sub2.heads_.begin(), sub2.heads_.end());
|
||||
new_sub.ends_.insert(new_sub.ends_.end(), sub1.ends_.begin(), sub1.ends_.end());
|
||||
new_sub.ends_.insert(new_sub.ends_.end(), sub2.ends_.begin(), sub2.ends_.end());
|
||||
sub_graphs_.erase(sub_graphs_.begin() + sub2_index);
|
||||
sub_graphs_.erase(sub_graphs_.begin() + sub1_index);
|
||||
sub_graphs_.insert(sub_graphs_.end(), std::move(new_sub));
|
||||
sub_graphs->erase(sub_graphs->begin() + sub2_index);
|
||||
sub_graphs->erase(sub_graphs->begin() + sub1_index);
|
||||
sub_graphs->insert(sub_graphs->end(), std::move(new_sub));
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
void SearchSubGraph::CalculateCostModel() {
|
||||
for (Subgraph &subgraph : sub_graphs_) {
|
||||
void SearchSubGraph::CalculateCostModel(std::vector<Subgraph> *sub_graphs) {
|
||||
total_cost_ = 0;
|
||||
for (Subgraph &subgraph : *sub_graphs) {
|
||||
subgraph.cost_.empty();
|
||||
std::vector<uint32_t> nodes = subgraph.nodes_;
|
||||
for (uint32_t node_index : nodes) {
|
||||
CostModel cost;
|
||||
cost.io_cost_ = 0;
|
||||
cost.mul_cost_ = 1;
|
||||
|
||||
Model::Node *node = model_->all_nodes_[node_index];
|
||||
if (GetPrimitiveType(node->primitive_) == schema::PrimitiveType_Conv2DFusion) {
|
||||
CostModel conv_cost = CalculateConv2DFusion(node);
|
||||
subgraph.cost_ = subgraph.cost_ + conv_cost;
|
||||
total_cost_ += conv_cost.cost();
|
||||
continue;
|
||||
cost = CalculateConv2DFusion(node);
|
||||
}
|
||||
|
||||
subgraph.cost_ = subgraph.cost_ + cost;
|
||||
total_cost_ += cost.cost();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SearchSubGraph::SubGraphSplitByOutput() {
|
||||
InitSearchSubGraphByOutput();
|
||||
CalculateCostModel();
|
||||
InitSubgraphRuntimeInfo();
|
||||
SubgraphFusion();
|
||||
ConvertSubGraphToModel();
|
||||
CalculateCostModel(&sub_graphs_);
|
||||
InitSubgraphRuntimeInfo(&sub_graphs_);
|
||||
SubgraphFusion(&sub_graphs_);
|
||||
ConvertSubGraphToModel(&sub_graphs_);
|
||||
}
|
||||
|
||||
void SearchSubGraph::SubGraphSplitByMiddle() {
|
||||
InitSearchSubGraphByMiddle();
|
||||
CalculateCostModel();
|
||||
InitSubgraphRuntimeInfo();
|
||||
SubgraphFusion();
|
||||
ConvertSubGraphToModel();
|
||||
for (auto map : node_sub_map_) {
|
||||
std::vector<Subgraph> &subgraphs = map.second;
|
||||
|
||||
CalculateCostModel(&subgraphs);
|
||||
InitSubgraphRuntimeInfo(&subgraphs);
|
||||
SubgraphFusion(&subgraphs);
|
||||
|
||||
MS_ASSERT(subgraphs.size() == 2);
|
||||
if (std::any_of(subgraphs.begin(), subgraphs.end(), [&](Subgraph &sub) { return sub.nodes_.empty(); })) {
|
||||
continue;
|
||||
}
|
||||
|
||||
OptimizeAfterFusion(&subgraphs, map.first);
|
||||
|
||||
/* redo cost-model and pre-set-info after optimize */
|
||||
CalculateCostModel(&subgraphs);
|
||||
if (subgraphs.at(0).cost_.cost() == 0 || subgraphs.at(1).cost_.cost() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
InitSubgraphRuntimeInfo(&subgraphs);
|
||||
|
||||
InitMainGraphDevice(DT_CPU);
|
||||
|
||||
ConvertSubGraphToModel(&subgraphs);
|
||||
}
|
||||
}
|
||||
|
||||
SearchSubGraph::SearchSubGraph(const InnerContext *context, Model *model, std::vector<lite::Tensor *> *src_tensors,
|
||||
|
@ -554,7 +831,7 @@ void SearchSubGraph::InitSearchParallelSubGraph() {
|
|||
void SearchSubGraph::SubGraphSplitByOffLineParallel() {
|
||||
MS_LOG(DEBUG) << "start to split offline parallel subgraph";
|
||||
InitSearchParallelSubGraph();
|
||||
ConvertSubGraphToModel();
|
||||
ConvertSubGraphToModel(&sub_graphs_);
|
||||
InitMainGraphDevice();
|
||||
MS_LOG(DEBUG) << "end to split offline parallel subgraph";
|
||||
}
|
||||
|
@ -564,6 +841,7 @@ void SearchSubGraph::SubGraphSplit() {
|
|||
SubGraphSplitByOffLineParallel();
|
||||
} else {
|
||||
SubGraphSplitByOutput();
|
||||
SubGraphSplitByMiddle();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
#include <stack>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include "include/model.h"
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/lite_model.h"
|
||||
|
@ -55,6 +57,10 @@ class SearchSubGraph {
|
|||
return result;
|
||||
}
|
||||
int cost() { return io_cost_ + mul_cost_; }
|
||||
void empty() {
|
||||
io_cost_ = 0;
|
||||
mul_cost_ = 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct Subgraph {
|
||||
|
@ -65,6 +71,7 @@ class SearchSubGraph {
|
|||
DeviceType device_;
|
||||
size_t thread_;
|
||||
CostModel cost_;
|
||||
uint32_t tid_; /* 0 (major device group) or 1 (minor device group) */
|
||||
};
|
||||
|
||||
public:
|
||||
|
@ -78,29 +85,43 @@ class SearchSubGraph {
|
|||
private:
|
||||
void SubGraphSplitByOutput();
|
||||
void InitSearchSubGraphByOutput();
|
||||
void InsertNode(uint32_t index, Subgraph *subgraph);
|
||||
|
||||
private:
|
||||
void SubGraphSplitByMiddle();
|
||||
void InitSearchSubGraphByMiddle();
|
||||
void SearchMultyInNodes(std::vector<uint32_t> *multy_in_nodes);
|
||||
void InitMiddleSubgraph(std::vector<uint32_t> *multy_in_nodes);
|
||||
void InsertNodeByMid(uint32_t node_index, Subgraph *subgraph);
|
||||
void InsertHeadNode(uint32_t index, Subgraph *subgraph);
|
||||
void OptimizeAfterFusion(std::vector<Subgraph> *sub_graphs, uint32_t root_node_index);
|
||||
std::unordered_map<uint32_t, std::vector<Subgraph>> node_sub_map_;
|
||||
|
||||
private:
|
||||
void SubGraphSplitByOffLineParallel();
|
||||
|
||||
private:
|
||||
void RemoveConstNode(std::vector<uint32_t> *nodes);
|
||||
void InitSearchTensor();
|
||||
void InitSearchParallelSubGraph();
|
||||
void ConvertSubGraphToModel();
|
||||
void InsertNode(uint32_t index, Subgraph *subgraph);
|
||||
void InitMainGraphDevice(DeviceType dt = DT_CPU);
|
||||
|
||||
void InitSubgraphRuntimeInfo(std::vector<Subgraph> *sub_graphs);
|
||||
void SubgraphFusion(std::vector<Subgraph> *sub_graphs);
|
||||
void CalculateCostModel(std::vector<Subgraph> *sub_graphs);
|
||||
void ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs);
|
||||
|
||||
private:
|
||||
void InsertParallelNode(uint32_t index, Subgraph *subgraph);
|
||||
bool IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes);
|
||||
bool IsNodeSubGraphHeadWithRoot(uint32_t node_index, const std::vector<uint32_t> &ready_nodes,
|
||||
uint32_t root_node_index);
|
||||
const schema::Primitive *CreatePartialPrimitive(int64_t subgraph_index);
|
||||
void InitSubgraphRuntimeInfo();
|
||||
void SubgraphFusion();
|
||||
void InitMainGraphDevice();
|
||||
void CalculateCostModel();
|
||||
|
||||
private:
|
||||
CostModel CalculateConv2DFusion(Model::Node *node);
|
||||
void dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group,
|
||||
std::vector<bool> *cor_group);
|
||||
std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs);
|
||||
|
||||
private:
|
||||
std::vector<size_t> *output_nodes_ = nullptr;
|
||||
|
|
|
@ -2,4 +2,5 @@
|
|||
# model run both CPU-CPU & GPU-CPU
|
||||
# model_file ### accuracy_limit ### enable_fp16(true or false)
|
||||
mtk_model_normalize_object_scene_ps_20200519_f32.tflite;0.5;false
|
||||
hiai_cv_poseEstimation.tflite;0.5;false
|
||||
# end
|
Loading…
Reference in New Issue