fix multi conv node

This commit is contained in:
z00512249 2021-06-23 12:03:40 +08:00
parent ac6d75b803
commit 45544cd82e
7 changed files with 171 additions and 59 deletions

View File

@ -50,7 +50,7 @@ void InnerContext::SetContextDevice(const Context *context) {
for (auto &device_ctx : context->device_list_) { for (auto &device_ctx : context->device_list_) {
// npu/gpu server would use one core so we don't bind core to avoid competition. // npu/gpu server would use one core so we don't bind core to avoid competition.
// If user does not set npu/gpu device, we still bind core. // If user does not set npu/gpu device, we still bind core.
if (device_ctx.device_type_ == DT_CPU && (isUserSetNPU || isUserSetGPU)) { if (device_ctx.device_type_ == DT_CPU && (isUserSetNPU || (isUserSetGPU && !enable_parallel_))) {
auto cpu_ctx = device_ctx; auto cpu_ctx = device_ctx;
cpu_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; cpu_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
this->device_list_.push_back(cpu_ctx); this->device_list_.push_back(cpu_ctx);

View File

@ -862,14 +862,14 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
} }
int LiteSession::InitGPURuntime() { int LiteSession::InitGPURuntime() {
CpuBindMode cpu_bind_mode = this->context_->device_list_.front().device_info_.cpu_device_info_.cpu_bind_mode_;
ActorThreadPool *thread_pool = this->context_->thread_pool(); ActorThreadPool *thread_pool = this->context_->thread_pool();
if (thread_pool == nullptr) { if (thread_pool == nullptr) {
MS_LOG(ERROR) << "thread pool is nullptr"; MS_LOG(ERROR) << "thread pool is nullptr";
is_running_.store(false); is_running_.store(false);
return RET_NULL_PTR; return RET_NULL_PTR;
} }
// Setting the binding core will affect the opencl drive scheduling. thread_pool->SetProcessAffinity(static_cast<BindMode>(cpu_bind_mode));
thread_pool->SetProcessAffinity(static_cast<BindMode>(NO_BIND));
#if GPU_OPENCL #if GPU_OPENCL
if (this->context_->IsGpuEnabled()) { if (this->context_->IsGpuEnabled()) {
opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper(); opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper();
@ -911,6 +911,8 @@ int LiteSession::InitGPURuntime() {
} }
} }
#endif #endif
// Setting the binding core will affect the opencl drive scheduling.
thread_pool->SetProcessAffinity(static_cast<BindMode>(NO_BIND));
return RET_OK; return RET_OK;
} }
} // namespace lite } // namespace lite

View File

@ -66,8 +66,8 @@ int SplitWithOverlapBaseCPUKernel::Init() { return RET_OK; }
int SplitWithOverlapBaseCPUKernel::ReSize() { return RET_OK; } int SplitWithOverlapBaseCPUKernel::ReSize() { return RET_OK; }
int SplitWithOverlapBaseCPUKernel::Split(int task_id) { int SplitWithOverlapBaseCPUKernel::Split(int task_id) {
DoSplitWithOverlap(input_ptr_, output_ptr_.data(), param_->num_split_, split_dim_size_, element_bytes_, DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), task_id, split_dim_size_, element_bytes_, outer_total_dim_,
outer_total_dim_, inner_stride_, start_indices_.data(), end_indices_.data()); inner_stride_, start_indices_.data(), end_indices_.data());
return RET_OK; return RET_OK;
} }
@ -117,7 +117,7 @@ int SplitWithOverlapBaseCPUKernel::Run() {
inner_stride_ *= input_shape[i]; inner_stride_ *= input_shape[i];
} }
auto ret = ParallelLaunch(this->context_, SplitWithOverlapRun, this, context_->thread_num_); auto ret = ParallelLaunch(this->context_, SplitWithOverlapRun, this, param_->num_split_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "ParallelLaunch for SplitWIthOverlapRun run fail. errorcode:[" << ret << "]"; MS_LOG(ERROR) << "ParallelLaunch for SplitWIthOverlapRun run fail. errorcode:[" << ret << "]";
return RET_ERROR; return RET_ERROR;

View File

@ -26,12 +26,12 @@
#include "src/ops/populate/populate_register.h" #include "src/ops/populate/populate_register.h"
#include "nnacl/fp32/winograd_utils.h" #include "nnacl/fp32/winograd_utils.h"
#include "nnacl/pooling_parameter.h" #include "nnacl/pooling_parameter.h"
#include "include/model.h"
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) #if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
#include "nnacl/fp32/conv_depthwise_fp32.h" #include "nnacl/fp32/conv_depthwise_fp32.h"
#endif #endif
namespace mindspore::lite { namespace mindspore::lite {
size_t CommConvMul(std::vector<int> weight_shape, std::vector<int> output_shape) { size_t CommConvMul(std::vector<int> weight_shape, std::vector<int> output_shape) {
size_t cost = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] * weight_shape[1] * size_t cost = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] * weight_shape[1] *
weight_shape[2] * weight_shape[3]; weight_shape[2] * weight_shape[3];
@ -54,6 +54,69 @@ size_t WinogradConvDwMul() {
return 0; return 0;
} }
// Returns true when the node is a Conv2DFusion primitive that was assigned a
// concrete device offline (device type differs from kDefaultDeviceType),
// i.e. a candidate for offline-parallel subgraph splitting.
bool IsOfflineParallelNode(const void *node_primitive, int node_device_type) {
  if (node_primitive == nullptr) {
    return false;
  }
  if (node_device_type == kDefaultDeviceType) {
    return false;
  }
  return GetPrimitiveType(node_primitive) == schema::PrimitiveType_Conv2DFusion;
}
void SearchSubGraph::UpdateOfflineParallelFlag() {
if (model_ == nullptr) {
offline_parallel_enable_ = false;
return;
}
// visited whole models to find any conv && depthwise conv have been set to device type
offline_parallel_enable_ =
std::any_of(this->model_->all_nodes_.begin(), this->model_->all_nodes_.end(),
[&](lite::Model::Node *node) { return IsOfflineParallelNode(node->primitive_, node->device_type_); });
}
// Validates that the candidate subgraphs form the exact offline-parallel
// pattern this pass can convert: exactly kDefalutSubGraphSize subgraphs, each
// with a single head and a single end node, where head and end are
// device-annotated convolutions, every non-const head input is produced by a
// SplitWithOverlap node, and every non-const end output is consumed by a
// Concat node. Returns false as soon as any condition fails.
bool SearchSubGraph::CheckIsParallelSubGraph(const std::vector<Subgraph> &subgraphs) {
  if (subgraphs.size() != kDefalutSubGraphSize) {
    return false;
  }
  for (const auto &sub_graph : subgraphs) {
    // Bind by const reference: the original copied both index vectors per
    // subgraph for no benefit.
    const auto &heads = sub_graph.heads_;
    const auto &ends = sub_graph.ends_;
    if (heads.size() != kDefaultInputs || ends.size() != kDefaultInputs) {
      return false;
    }
    auto head_node = model_->all_nodes_.at(heads.front());
    auto end_node = model_->all_nodes_.at(ends.front());
    if (!IsOfflineParallelNode(head_node->primitive_, head_node->device_type_) ||
        !IsOfflineParallelNode(end_node->primitive_, end_node->device_type_)) {
      return false;
    }
    // 1. check head_node's input is SplitOverlap node
    for (const auto &input : head_node->input_indices_) {
      if (tensors_.at(input).type_ == CONST) {
        continue;
      }
      // A non-const tensor with no producer (e.g. a graph input) cannot come
      // from a SplitWithOverlap node; guard front() on an empty vector (UB).
      if (tensors_.at(input).out_nodes_.empty()) {
        return false;
      }
      auto input_node_index = tensors_.at(input).out_nodes_.front();
      if (GetPrimitiveType(model_->all_nodes_.at(input_node_index)->primitive_) !=
          schema::PrimitiveType_SplitWithOverlap) {
        return false;
      }
    }
    // 2. check end_node's output is concat node
    for (const auto &output : end_node->output_indices_) {
      if (tensors_.at(output).type_ == CONST) {
        continue;
      }
      // Likewise, an output with no consumer cannot feed a Concat node; avoid
      // calling front() on an empty in_nodes_ vector.
      if (tensors_.at(output).in_nodes_.empty()) {
        return false;
      }
      auto output_node_index = tensors_.at(output).in_nodes_.front();
      if (GetPrimitiveType(model_->all_nodes_.at(output_node_index)->primitive_) != schema::PrimitiveType_Concat) {
        return false;
      }
    }
  }
  return true;
}
void SearchSubGraph::dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group, void SearchSubGraph::dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group,
std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs) { std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs) {
if (i == n) { if (i == n) {
@ -146,7 +209,7 @@ const schema::Primitive *SearchSubGraph::CreatePartialPrimitive(int64_t subgraph
} }
void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) { void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
if (sub_graphs->size() != 2) { if (sub_graphs->size() != kDefalutSubGraphSize) {
return; return;
} }
Model::SubGraph *main_graphs = model_->sub_graphs_.front(); Model::SubGraph *main_graphs = model_->sub_graphs_.front();
@ -166,8 +229,7 @@ void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
MS_LOG(ERROR) << "New sub graph failed!"; MS_LOG(ERROR) << "New sub graph failed!";
return; return;
} }
new_sub_graph->name_ = "Subgraph-split-" + std::to_string(new_sub_index); new_sub_graph->name_ = "subgraph-split-" + std::to_string(new_sub_index);
Model::Node *new_partial_node = new (std::nothrow) Model::Node(); Model::Node *new_partial_node = new (std::nothrow) Model::Node();
if (new_partial_node == nullptr) { if (new_partial_node == nullptr) {
MS_LOG(ERROR) << "New partial node failed!"; MS_LOG(ERROR) << "New partial node failed!";
@ -175,6 +237,16 @@ void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
return; return;
} }
new_partial_node->name_ = "Partial-subgraph-split-" + std::to_string(new_sub_index); new_partial_node->name_ = "Partial-subgraph-split-" + std::to_string(new_sub_index);
if (device_type == DT_CPU) {
new_partial_node->name_ = "cpu_" + new_partial_node->name_;
} else if (device_type == DT_GPU) {
new_partial_node->name_ = "gpu_" + new_partial_node->name_;
} else if (device_type == DT_NPU) {
new_partial_node->name_ = "npu_" + new_partial_node->name_;
} else {
new_partial_node->name_ = "unknow_" + new_partial_node->name_;
}
new_partial_node->node_type_ = mindspore::lite::NodeType_ValueNode; new_partial_node->node_type_ = mindspore::lite::NodeType_ValueNode;
new_partial_node->primitive_ = CreatePartialPrimitive(new_sub_index); new_partial_node->primitive_ = CreatePartialPrimitive(new_sub_index);
@ -221,7 +293,7 @@ void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
} }
bool SearchSubGraph::IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes) { bool SearchSubGraph::IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes) {
std::vector<uint32_t> output_indexes = node_list_.at(node_index)->output_indices_; std::vector<uint32_t> output_indexes = model_->all_nodes_.at(node_index)->output_indices_;
std::vector<uint32_t> output_nodes; std::vector<uint32_t> output_nodes;
for (uint32_t out_t : output_indexes) { for (uint32_t out_t : output_indexes) {
std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_; std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_;
@ -703,7 +775,7 @@ void SearchSubGraph::SubGraphSplitByMiddle() {
InitSubgraphRuntimeInfo(&subgraphs); InitSubgraphRuntimeInfo(&subgraphs);
SubgraphFusion(&subgraphs); SubgraphFusion(&subgraphs);
MS_ASSERT(subgraphs.size() == 2); MS_ASSERT(subgraphs.size() == kDefalutSubGraphSize);
if (std::any_of(subgraphs.begin(), subgraphs.end(), [&](Subgraph &sub) { return sub.nodes_.empty(); })) { if (std::any_of(subgraphs.begin(), subgraphs.end(), [&](Subgraph &sub) { return sub.nodes_.empty(); })) {
continue; continue;
} }
@ -724,6 +796,60 @@ void SearchSubGraph::SubGraphSplitByMiddle() {
} }
} }
// Splits the model into parallel subgraphs according to the device annotations
// that were assigned offline. For each multi-input Concat node, one candidate
// subgraph is grown backwards from each non-const input; candidates that pass
// CheckIsParallelSubGraph are converted into partial-call subgraphs with
// per-device thread assignments, and the remaining main graph runs on CPU.
void SearchSubGraph::SubGraphSplitByOffLineParallel() {
  sub_graphs_.clear();
  node_list_ = model_->all_nodes_;
  std::vector<uint32_t> multy_in_nodes;
  SearchMultyInNodes(&multy_in_nodes);
  for (uint32_t node_index : multy_in_nodes) {
    Model::Node *node = node_list_[node_index];
    // Only Concat join points are considered as the merge node of the
    // offline-parallel pattern.
    if (GetPrimitiveType(node->primitive_) != schema::PrimitiveType_Concat) {
      continue;
    }
    std::vector<Subgraph> node_subs;
    for (uint32_t input_tensor_index : node->input_indices_) {
      Tensor *tensor = &tensors_[input_tensor_index];
      if (tensor->type_ == CONST) continue;
      std::vector<uint32_t> input_nodes = tensor->out_nodes_;
      // NOTE(review): input_nodes[0] assumes every non-const Concat input has a
      // producer node — a producer-less tensor (graph input) would index an
      // empty vector here; confirm SearchMultyInNodes rules that case out.
      Subgraph sub;
      sub.ends_.push_back(input_nodes[0]);
      // InsertNodeByMid grows the subgraph backwards from the end node and is
      // presumably responsible for filling sub.heads_ — verify in its body.
      InsertNodeByMid(input_nodes[0], &sub);
      node_subs.push_back(sub);
    }
    node_sub_map_.insert(std::make_pair(node_index, node_subs));
  }
  for (auto map : node_sub_map_) {
    // `map` is a per-iteration copy, so the device/thread mutations below only
    // affect the local copy handed to ConvertSubGraphToModel, not the stored
    // node_sub_map_ entry.
    std::vector<Subgraph> &subgraphs = map.second;
    if (std::any_of(subgraphs.begin(), subgraphs.end(), [&](Subgraph &sub) { return sub.nodes_.empty(); })) {
      continue;
    }
    // Skip candidates that do not match the exact split/conv/concat pattern.
    if (!CheckIsParallelSubGraph(subgraphs)) {
      continue;
    }
    // init graph device type
    for (auto &subgraph : subgraphs) {
      uint32_t head_node_index = subgraph.heads_.front();
      subgraph.device_ = static_cast<lite::DeviceType>(model_->all_nodes_.at(head_node_index)->device_type_);
      // GPU branch gets the major thread pool share and tid 0; every other
      // device (CPU/NPU) falls into the minor share with tid 1.
      if (subgraph.device_ == DT_GPU) {
        subgraph.thread_ = major_thread_;
        subgraph.tid_ = 0;
      } else {
        subgraph.thread_ = minor_thread_;
        subgraph.tid_ = 1;
      }
    }
    ConvertSubGraphToModel(&subgraphs);
  }
  // The residual main graph (everything not split out above) runs on CPU.
  InitMainGraphDevice(DT_CPU);
}
SearchSubGraph::SearchSubGraph(const InnerContext *context, Model *model, std::vector<lite::Tensor *> *src_tensors, SearchSubGraph::SearchSubGraph(const InnerContext *context, Model *model, std::vector<lite::Tensor *> *src_tensors,
const std::map<int, OpParameter *> *op_parameters, std::vector<size_t> *output_nodes) const std::map<int, OpParameter *> *op_parameters, std::vector<size_t> *output_nodes)
: output_nodes_(output_nodes), context_(context), src_tensors_(src_tensors), op_parameters_(op_parameters) { : output_nodes_(output_nodes), context_(context), src_tensors_(src_tensors), op_parameters_(op_parameters) {
@ -756,29 +882,20 @@ void SearchSubGraph::InsertParallelNode(uint32_t index, Subgraph *subgraph) {
return; return;
} }
if (subgraph->search_terminate_) { if (subgraph->search_terminate_) {
return; if (!subgraph->nodes_.empty()) {
sub_graphs_.push_back(std::move(*subgraph));
}
Subgraph new_graph;
subgraph = &new_graph;
} }
Model::Node *node = node_list_[index]; Model::Node *node = node_list_[index];
// has been searched // has been searched
if (node == nullptr) { if (node == nullptr) {
return; return;
} }
// just deal with parallel target node
std::vector<uint32_t> input = node->input_indices_; // if current node is parallel target node
/* remove const node */ if (IsOfflineParallelNode(node->primitive_, node->device_type_)) {
for (int i = static_cast<int>(input.size()) - 1; i >= 0; i--) {
if (tensors_[input[i]].type_ == CONST) {
VectorErase(&input, input[i]);
}
}
// search to graph to graph input , terminate it.
if (std::any_of(input.begin(), input.end(), [&](int input_index) { return tensors_[input_index].type_ == INPUT; })) {
subgraph->search_terminate_ = true;
return;
}
// if current node is no parallel target node, just judge terminate or continue
if (GetPrimitiveType(node->primitive_) == schema::PrimitiveType_Conv2DFusion &&
node->device_type_ != kDefaultDeviceType) {
// first searched // first searched
if (subgraph->nodes_.empty()) { if (subgraph->nodes_.empty()) {
subgraph->device_ = static_cast<DeviceType>(node->device_type_); subgraph->device_ = static_cast<DeviceType>(node->device_type_);
@ -788,28 +905,28 @@ void SearchSubGraph::InsertParallelNode(uint32_t index, Subgraph *subgraph) {
return; return;
} }
} }
if (IsNodeSubGraphHead(index, subgraph->nodes_)) {
if (subgraph->nodes_.empty()) {
subgraph->search_terminate_ = true;
return;
}
subgraph->heads_.push_back(subgraph->nodes_.front());
return;
}
// for offline parallel target subgraph only has one end
if (subgraph->ends_.empty()) {
subgraph->ends_.push_back(index);
}
subgraph->nodes_.insert(subgraph->nodes_.begin(), index); subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
node_list_[index] = nullptr; node_list_[index] = nullptr;
} else { } else {
if (!subgraph->nodes_.empty()) { subgraph->search_terminate_ = true;
return; }
// just deal with parallel target node
std::vector<uint32_t> input = node->input_indices_;
/* remove const node */
for (int i = static_cast<int>(input.size()) - 1; i >= 0; i--) {
if (tensors_[input[i]].type_ == CONST) {
VectorErase(&input, input[i]);
} }
} }
// search to graph to graph input , terminate it.
if (std::any_of(input.begin(), input.end(), [&](int input_index) { return tensors_[input_index].type_ == INPUT; })) {
subgraph->search_terminate_ = true;
return;
}
// search for next nodes // search for next nodes
for (uint32_t next : input) { for (uint32_t next : input) {
auto next_nodes = tensors_[next].out_nodes_; auto next_nodes = tensors_[next].out_nodes_;
@ -828,15 +945,8 @@ void SearchSubGraph::InitSearchParallelSubGraph() {
} }
} }
void SearchSubGraph::SubGraphSplitByOffLineParallel() {
MS_LOG(DEBUG) << "start to split offline parallel subgraph";
InitSearchParallelSubGraph();
ConvertSubGraphToModel(&sub_graphs_);
InitMainGraphDevice();
MS_LOG(DEBUG) << "end to split offline parallel subgraph";
}
void SearchSubGraph::SubGraphSplit() { void SearchSubGraph::SubGraphSplit() {
UpdateOfflineParallelFlag();
if (offline_parallel_enable_) { if (offline_parallel_enable_) {
SubGraphSplitByOffLineParallel(); SubGraphSplitByOffLineParallel();
} else { } else {

View File

@ -31,6 +31,8 @@
namespace mindspore::lite { namespace mindspore::lite {
constexpr int kDefaultDeviceType = -1; constexpr int kDefaultDeviceType = -1;
constexpr int kDefalutSubGraphSize = 2;
constexpr int kDefaultInputs = 1;
class SearchSubGraph { class SearchSubGraph {
enum TensorType { NORMAL, CONST, INPUT }; enum TensorType { NORMAL, CONST, INPUT };
@ -122,6 +124,8 @@ class SearchSubGraph {
CostModel CalculateConv2DFusion(Model::Node *node); CostModel CalculateConv2DFusion(Model::Node *node);
void dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group, void dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group,
std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs); std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs);
void UpdateOfflineParallelFlag();
bool CheckIsParallelSubGraph(const std::vector<Subgraph> &subgraphs);
private: private:
std::vector<size_t> *output_nodes_ = nullptr; std::vector<size_t> *output_nodes_ = nullptr;

View File

@ -291,8 +291,6 @@ AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePt
concate_cnode->set_scope(conv_cnode->scope()); concate_cnode->set_scope(conv_cnode->scope());
std::vector<AnfNodePtr> outputs; std::vector<AnfNodePtr> outputs;
GetMultipleOutputsOfAnfNode(func_graph, concate_cnode, 1, &outputs); GetMultipleOutputsOfAnfNode(func_graph, concate_cnode, 1, &outputs);
// only support split_overlap node implementation, to split sub_graph for runtime
concate_cnode->AddAttr(mindspore::ops::kDeviceType, MakeValue(static_cast<int>(lite::DT_CPU)));
return concate_cnode; return concate_cnode;
} }
@ -337,8 +335,6 @@ void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNo
ptr_list.push_back(value_node); ptr_list.push_back(value_node);
} }
split_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list)); split_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
// only support split_overlap node implementation
split_cnode->AddAttr(mindspore::ops::kDeviceType, MakeValue(static_cast<int>(lite::DT_CPU)));
} }
} // namespace opt } // namespace opt

View File

@ -89,12 +89,12 @@ bool MultiConvSplit::CheckSplitValid() {
return false; return false;
} }
int64_t split_axis_value_0 = UP_DIV(split_info_.ori_split_axis_value * visited_block, total_block_count); int64_t split_axis_value_0 = UP_DIV(split_info_.ori_split_axis_value * visited_block, total_block_count);
if (split_axis_value_0 >= split_info_.ori_split_axis_value) { if (split_axis_value_0 > split_info_.ori_split_axis_value) {
return false; return false;
} }
int64_t split_axis_value_1 = split_info_.ori_split_axis_value - split_axis_value_0; int64_t split_axis_value_1 = split_info_.ori_split_axis_value - split_axis_value_0;
split_axis_value_1 += (split_info_.extend_top.back() + split_info_.extend_bottom.back()); split_axis_value_1 += (split_info_.extend_top.back() + split_info_.extend_bottom.back());
return split_axis_value_1 < split_info_.ori_split_axis_value; return split_axis_value_1 <= split_info_.ori_split_axis_value;
} }
int MultiConvSplit::GetMultiConvNodes(const AnfNodePtr &conv_node) { int MultiConvSplit::GetMultiConvNodes(const AnfNodePtr &conv_node) {