fix multi conv node
This commit is contained in:
parent
ac6d75b803
commit
45544cd82e
|
@ -50,7 +50,7 @@ void InnerContext::SetContextDevice(const Context *context) {
|
||||||
for (auto &device_ctx : context->device_list_) {
|
for (auto &device_ctx : context->device_list_) {
|
||||||
// npu/gpu server would use one core so we don't bind core to avoid competition.
|
// npu/gpu server would use one core so we don't bind core to avoid competition.
|
||||||
// If user does not set npu/gpu device, we still bind core.
|
// If user does not set npu/gpu device, we still bind core.
|
||||||
if (device_ctx.device_type_ == DT_CPU && (isUserSetNPU || isUserSetGPU)) {
|
if (device_ctx.device_type_ == DT_CPU && (isUserSetNPU || (isUserSetGPU && !enable_parallel_))) {
|
||||||
auto cpu_ctx = device_ctx;
|
auto cpu_ctx = device_ctx;
|
||||||
cpu_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
|
cpu_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
|
||||||
this->device_list_.push_back(cpu_ctx);
|
this->device_list_.push_back(cpu_ctx);
|
||||||
|
|
|
@ -862,14 +862,14 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
|
||||||
}
|
}
|
||||||
|
|
||||||
int LiteSession::InitGPURuntime() {
|
int LiteSession::InitGPURuntime() {
|
||||||
|
CpuBindMode cpu_bind_mode = this->context_->device_list_.front().device_info_.cpu_device_info_.cpu_bind_mode_;
|
||||||
ActorThreadPool *thread_pool = this->context_->thread_pool();
|
ActorThreadPool *thread_pool = this->context_->thread_pool();
|
||||||
if (thread_pool == nullptr) {
|
if (thread_pool == nullptr) {
|
||||||
MS_LOG(ERROR) << "thread pool is nullptr";
|
MS_LOG(ERROR) << "thread pool is nullptr";
|
||||||
is_running_.store(false);
|
is_running_.store(false);
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
// Setting the binding core will affect the opencl drive scheduling.
|
thread_pool->SetProcessAffinity(static_cast<BindMode>(cpu_bind_mode));
|
||||||
thread_pool->SetProcessAffinity(static_cast<BindMode>(NO_BIND));
|
|
||||||
#if GPU_OPENCL
|
#if GPU_OPENCL
|
||||||
if (this->context_->IsGpuEnabled()) {
|
if (this->context_->IsGpuEnabled()) {
|
||||||
opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper();
|
opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper();
|
||||||
|
@ -911,6 +911,8 @@ int LiteSession::InitGPURuntime() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
// Setting the binding core will affect the opencl drive scheduling.
|
||||||
|
thread_pool->SetProcessAffinity(static_cast<BindMode>(NO_BIND));
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
|
|
|
@ -66,8 +66,8 @@ int SplitWithOverlapBaseCPUKernel::Init() { return RET_OK; }
|
||||||
int SplitWithOverlapBaseCPUKernel::ReSize() { return RET_OK; }
|
int SplitWithOverlapBaseCPUKernel::ReSize() { return RET_OK; }
|
||||||
|
|
||||||
int SplitWithOverlapBaseCPUKernel::Split(int task_id) {
|
int SplitWithOverlapBaseCPUKernel::Split(int task_id) {
|
||||||
DoSplitWithOverlap(input_ptr_, output_ptr_.data(), param_->num_split_, split_dim_size_, element_bytes_,
|
DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), task_id, split_dim_size_, element_bytes_, outer_total_dim_,
|
||||||
outer_total_dim_, inner_stride_, start_indices_.data(), end_indices_.data());
|
inner_stride_, start_indices_.data(), end_indices_.data());
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -117,7 +117,7 @@ int SplitWithOverlapBaseCPUKernel::Run() {
|
||||||
inner_stride_ *= input_shape[i];
|
inner_stride_ *= input_shape[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ret = ParallelLaunch(this->context_, SplitWithOverlapRun, this, context_->thread_num_);
|
auto ret = ParallelLaunch(this->context_, SplitWithOverlapRun, this, param_->num_split_);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "ParallelLaunch for SplitWIthOverlapRun run fail. errorcode:[" << ret << "]";
|
MS_LOG(ERROR) << "ParallelLaunch for SplitWIthOverlapRun run fail. errorcode:[" << ret << "]";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
|
|
@ -26,12 +26,12 @@
|
||||||
#include "src/ops/populate/populate_register.h"
|
#include "src/ops/populate/populate_register.h"
|
||||||
#include "nnacl/fp32/winograd_utils.h"
|
#include "nnacl/fp32/winograd_utils.h"
|
||||||
#include "nnacl/pooling_parameter.h"
|
#include "nnacl/pooling_parameter.h"
|
||||||
|
#include "include/model.h"
|
||||||
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
|
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
|
||||||
#include "nnacl/fp32/conv_depthwise_fp32.h"
|
#include "nnacl/fp32/conv_depthwise_fp32.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
|
|
||||||
size_t CommConvMul(std::vector<int> weight_shape, std::vector<int> output_shape) {
|
size_t CommConvMul(std::vector<int> weight_shape, std::vector<int> output_shape) {
|
||||||
size_t cost = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] * weight_shape[1] *
|
size_t cost = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] * weight_shape[1] *
|
||||||
weight_shape[2] * weight_shape[3];
|
weight_shape[2] * weight_shape[3];
|
||||||
|
@ -54,6 +54,69 @@ size_t WinogradConvDwMul() {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsOfflineParallelNode(const void *node_primitive, int node_device_type) {
|
||||||
|
if (node_primitive == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return (GetPrimitiveType(node_primitive) == schema::PrimitiveType_Conv2DFusion) &&
|
||||||
|
(node_device_type != kDefaultDeviceType);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SearchSubGraph::UpdateOfflineParallelFlag() {
|
||||||
|
if (model_ == nullptr) {
|
||||||
|
offline_parallel_enable_ = false;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// visited whole models to find any conv && depthwise conv have been set to device type
|
||||||
|
offline_parallel_enable_ =
|
||||||
|
std::any_of(this->model_->all_nodes_.begin(), this->model_->all_nodes_.end(),
|
||||||
|
[&](lite::Model::Node *node) { return IsOfflineParallelNode(node->primitive_, node->device_type_); });
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SearchSubGraph::CheckIsParallelSubGraph(const std::vector<Subgraph> &subgraphs) {
|
||||||
|
if (subgraphs.size() != kDefalutSubGraphSize) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto &sub_graph : subgraphs) {
|
||||||
|
auto heads = sub_graph.heads_;
|
||||||
|
auto ends = sub_graph.ends_;
|
||||||
|
if (heads.size() != kDefaultInputs || ends.size() != kDefaultInputs) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto head_node = model_->all_nodes_.at(heads.front());
|
||||||
|
auto end_node = model_->all_nodes_.at(ends.front());
|
||||||
|
if (!IsOfflineParallelNode(head_node->primitive_, head_node->device_type_) ||
|
||||||
|
!IsOfflineParallelNode(end_node->primitive_, end_node->device_type_)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. check head_node's input is SplitOverlap node
|
||||||
|
for (const auto &input : head_node->input_indices_) {
|
||||||
|
if (tensors_.at(input).type_ == CONST) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto input_node_index = tensors_.at(input).out_nodes_.front();
|
||||||
|
if (GetPrimitiveType(model_->all_nodes_.at(input_node_index)->primitive_) !=
|
||||||
|
schema::PrimitiveType_SplitWithOverlap) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. check end_node's output is concat node
|
||||||
|
for (const auto &output : end_node->output_indices_) {
|
||||||
|
if (tensors_.at(output).type_ == CONST) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto output_node_index = tensors_.at(output).in_nodes_.front();
|
||||||
|
if (GetPrimitiveType(model_->all_nodes_.at(output_node_index)->primitive_) != schema::PrimitiveType_Concat) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
void SearchSubGraph::dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group,
|
void SearchSubGraph::dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group,
|
||||||
std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs) {
|
std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs) {
|
||||||
if (i == n) {
|
if (i == n) {
|
||||||
|
@ -146,7 +209,7 @@ const schema::Primitive *SearchSubGraph::CreatePartialPrimitive(int64_t subgraph
|
||||||
}
|
}
|
||||||
|
|
||||||
void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
|
void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
|
||||||
if (sub_graphs->size() != 2) {
|
if (sub_graphs->size() != kDefalutSubGraphSize) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Model::SubGraph *main_graphs = model_->sub_graphs_.front();
|
Model::SubGraph *main_graphs = model_->sub_graphs_.front();
|
||||||
|
@ -166,8 +229,7 @@ void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
|
||||||
MS_LOG(ERROR) << "New sub graph failed!";
|
MS_LOG(ERROR) << "New sub graph failed!";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
new_sub_graph->name_ = "Subgraph-split-" + std::to_string(new_sub_index);
|
new_sub_graph->name_ = "subgraph-split-" + std::to_string(new_sub_index);
|
||||||
|
|
||||||
Model::Node *new_partial_node = new (std::nothrow) Model::Node();
|
Model::Node *new_partial_node = new (std::nothrow) Model::Node();
|
||||||
if (new_partial_node == nullptr) {
|
if (new_partial_node == nullptr) {
|
||||||
MS_LOG(ERROR) << "New partial node failed!";
|
MS_LOG(ERROR) << "New partial node failed!";
|
||||||
|
@ -175,6 +237,16 @@ void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
new_partial_node->name_ = "Partial-subgraph-split-" + std::to_string(new_sub_index);
|
new_partial_node->name_ = "Partial-subgraph-split-" + std::to_string(new_sub_index);
|
||||||
|
if (device_type == DT_CPU) {
|
||||||
|
new_partial_node->name_ = "cpu_" + new_partial_node->name_;
|
||||||
|
} else if (device_type == DT_GPU) {
|
||||||
|
new_partial_node->name_ = "gpu_" + new_partial_node->name_;
|
||||||
|
|
||||||
|
} else if (device_type == DT_NPU) {
|
||||||
|
new_partial_node->name_ = "npu_" + new_partial_node->name_;
|
||||||
|
} else {
|
||||||
|
new_partial_node->name_ = "unknow_" + new_partial_node->name_;
|
||||||
|
}
|
||||||
new_partial_node->node_type_ = mindspore::lite::NodeType_ValueNode;
|
new_partial_node->node_type_ = mindspore::lite::NodeType_ValueNode;
|
||||||
new_partial_node->primitive_ = CreatePartialPrimitive(new_sub_index);
|
new_partial_node->primitive_ = CreatePartialPrimitive(new_sub_index);
|
||||||
|
|
||||||
|
@ -221,7 +293,7 @@ void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SearchSubGraph::IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes) {
|
bool SearchSubGraph::IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes) {
|
||||||
std::vector<uint32_t> output_indexes = node_list_.at(node_index)->output_indices_;
|
std::vector<uint32_t> output_indexes = model_->all_nodes_.at(node_index)->output_indices_;
|
||||||
std::vector<uint32_t> output_nodes;
|
std::vector<uint32_t> output_nodes;
|
||||||
for (uint32_t out_t : output_indexes) {
|
for (uint32_t out_t : output_indexes) {
|
||||||
std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_;
|
std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_;
|
||||||
|
@ -703,7 +775,7 @@ void SearchSubGraph::SubGraphSplitByMiddle() {
|
||||||
InitSubgraphRuntimeInfo(&subgraphs);
|
InitSubgraphRuntimeInfo(&subgraphs);
|
||||||
SubgraphFusion(&subgraphs);
|
SubgraphFusion(&subgraphs);
|
||||||
|
|
||||||
MS_ASSERT(subgraphs.size() == 2);
|
MS_ASSERT(subgraphs.size() == kDefalutSubGraphSize);
|
||||||
if (std::any_of(subgraphs.begin(), subgraphs.end(), [&](Subgraph &sub) { return sub.nodes_.empty(); })) {
|
if (std::any_of(subgraphs.begin(), subgraphs.end(), [&](Subgraph &sub) { return sub.nodes_.empty(); })) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -724,6 +796,60 @@ void SearchSubGraph::SubGraphSplitByMiddle() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SearchSubGraph::SubGraphSplitByOffLineParallel() {
|
||||||
|
sub_graphs_.clear();
|
||||||
|
node_list_ = model_->all_nodes_;
|
||||||
|
|
||||||
|
std::vector<uint32_t> multy_in_nodes;
|
||||||
|
|
||||||
|
SearchMultyInNodes(&multy_in_nodes);
|
||||||
|
|
||||||
|
for (uint32_t node_index : multy_in_nodes) {
|
||||||
|
Model::Node *node = node_list_[node_index];
|
||||||
|
if (GetPrimitiveType(node->primitive_) != schema::PrimitiveType_Concat) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::vector<Subgraph> node_subs;
|
||||||
|
for (uint32_t input_tensor_index : node->input_indices_) {
|
||||||
|
Tensor *tensor = &tensors_[input_tensor_index];
|
||||||
|
if (tensor->type_ == CONST) continue;
|
||||||
|
std::vector<uint32_t> input_nodes = tensor->out_nodes_;
|
||||||
|
Subgraph sub;
|
||||||
|
sub.ends_.push_back(input_nodes[0]);
|
||||||
|
InsertNodeByMid(input_nodes[0], &sub);
|
||||||
|
node_subs.push_back(sub);
|
||||||
|
}
|
||||||
|
node_sub_map_.insert(std::make_pair(node_index, node_subs));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto map : node_sub_map_) {
|
||||||
|
std::vector<Subgraph> &subgraphs = map.second;
|
||||||
|
|
||||||
|
if (std::any_of(subgraphs.begin(), subgraphs.end(), [&](Subgraph &sub) { return sub.nodes_.empty(); })) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!CheckIsParallelSubGraph(subgraphs)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// init graph device type
|
||||||
|
for (auto &subgraph : subgraphs) {
|
||||||
|
uint32_t head_node_index = subgraph.heads_.front();
|
||||||
|
subgraph.device_ = static_cast<lite::DeviceType>(model_->all_nodes_.at(head_node_index)->device_type_);
|
||||||
|
if (subgraph.device_ == DT_GPU) {
|
||||||
|
subgraph.thread_ = major_thread_;
|
||||||
|
subgraph.tid_ = 0;
|
||||||
|
} else {
|
||||||
|
subgraph.thread_ = minor_thread_;
|
||||||
|
subgraph.tid_ = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ConvertSubGraphToModel(&subgraphs);
|
||||||
|
}
|
||||||
|
InitMainGraphDevice(DT_CPU);
|
||||||
|
}
|
||||||
|
|
||||||
SearchSubGraph::SearchSubGraph(const InnerContext *context, Model *model, std::vector<lite::Tensor *> *src_tensors,
|
SearchSubGraph::SearchSubGraph(const InnerContext *context, Model *model, std::vector<lite::Tensor *> *src_tensors,
|
||||||
const std::map<int, OpParameter *> *op_parameters, std::vector<size_t> *output_nodes)
|
const std::map<int, OpParameter *> *op_parameters, std::vector<size_t> *output_nodes)
|
||||||
: output_nodes_(output_nodes), context_(context), src_tensors_(src_tensors), op_parameters_(op_parameters) {
|
: output_nodes_(output_nodes), context_(context), src_tensors_(src_tensors), op_parameters_(op_parameters) {
|
||||||
|
@ -756,29 +882,20 @@ void SearchSubGraph::InsertParallelNode(uint32_t index, Subgraph *subgraph) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (subgraph->search_terminate_) {
|
if (subgraph->search_terminate_) {
|
||||||
return;
|
if (!subgraph->nodes_.empty()) {
|
||||||
|
sub_graphs_.push_back(std::move(*subgraph));
|
||||||
|
}
|
||||||
|
Subgraph new_graph;
|
||||||
|
subgraph = &new_graph;
|
||||||
}
|
}
|
||||||
Model::Node *node = node_list_[index];
|
Model::Node *node = node_list_[index];
|
||||||
// has been searched
|
// has been searched
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// just deal with parallel target node
|
|
||||||
std::vector<uint32_t> input = node->input_indices_;
|
// if current node is parallel target node
|
||||||
/* remove const node */
|
if (IsOfflineParallelNode(node->primitive_, node->device_type_)) {
|
||||||
for (int i = static_cast<int>(input.size()) - 1; i >= 0; i--) {
|
|
||||||
if (tensors_[input[i]].type_ == CONST) {
|
|
||||||
VectorErase(&input, input[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// search to graph to graph input , terminate it.
|
|
||||||
if (std::any_of(input.begin(), input.end(), [&](int input_index) { return tensors_[input_index].type_ == INPUT; })) {
|
|
||||||
subgraph->search_terminate_ = true;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// if current node is no parallel target node, just judge terminate or continue
|
|
||||||
if (GetPrimitiveType(node->primitive_) == schema::PrimitiveType_Conv2DFusion &&
|
|
||||||
node->device_type_ != kDefaultDeviceType) {
|
|
||||||
// first searched
|
// first searched
|
||||||
if (subgraph->nodes_.empty()) {
|
if (subgraph->nodes_.empty()) {
|
||||||
subgraph->device_ = static_cast<DeviceType>(node->device_type_);
|
subgraph->device_ = static_cast<DeviceType>(node->device_type_);
|
||||||
|
@ -788,28 +905,28 @@ void SearchSubGraph::InsertParallelNode(uint32_t index, Subgraph *subgraph) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (IsNodeSubGraphHead(index, subgraph->nodes_)) {
|
|
||||||
if (subgraph->nodes_.empty()) {
|
|
||||||
subgraph->search_terminate_ = true;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
subgraph->heads_.push_back(subgraph->nodes_.front());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// for offline parallel target subgraph only has one end
|
|
||||||
if (subgraph->ends_.empty()) {
|
|
||||||
subgraph->ends_.push_back(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
|
subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
|
||||||
node_list_[index] = nullptr;
|
node_list_[index] = nullptr;
|
||||||
} else {
|
} else {
|
||||||
if (!subgraph->nodes_.empty()) {
|
subgraph->search_terminate_ = true;
|
||||||
return;
|
}
|
||||||
|
|
||||||
|
// just deal with parallel target node
|
||||||
|
std::vector<uint32_t> input = node->input_indices_;
|
||||||
|
|
||||||
|
/* remove const node */
|
||||||
|
for (int i = static_cast<int>(input.size()) - 1; i >= 0; i--) {
|
||||||
|
if (tensors_[input[i]].type_ == CONST) {
|
||||||
|
VectorErase(&input, input[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// search to graph to graph input , terminate it.
|
||||||
|
if (std::any_of(input.begin(), input.end(), [&](int input_index) { return tensors_[input_index].type_ == INPUT; })) {
|
||||||
|
subgraph->search_terminate_ = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// search for next nodes
|
// search for next nodes
|
||||||
for (uint32_t next : input) {
|
for (uint32_t next : input) {
|
||||||
auto next_nodes = tensors_[next].out_nodes_;
|
auto next_nodes = tensors_[next].out_nodes_;
|
||||||
|
@ -828,15 +945,8 @@ void SearchSubGraph::InitSearchParallelSubGraph() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SearchSubGraph::SubGraphSplitByOffLineParallel() {
|
|
||||||
MS_LOG(DEBUG) << "start to split offline parallel subgraph";
|
|
||||||
InitSearchParallelSubGraph();
|
|
||||||
ConvertSubGraphToModel(&sub_graphs_);
|
|
||||||
InitMainGraphDevice();
|
|
||||||
MS_LOG(DEBUG) << "end to split offline parallel subgraph";
|
|
||||||
}
|
|
||||||
|
|
||||||
void SearchSubGraph::SubGraphSplit() {
|
void SearchSubGraph::SubGraphSplit() {
|
||||||
|
UpdateOfflineParallelFlag();
|
||||||
if (offline_parallel_enable_) {
|
if (offline_parallel_enable_) {
|
||||||
SubGraphSplitByOffLineParallel();
|
SubGraphSplitByOffLineParallel();
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -31,6 +31,8 @@
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
constexpr int kDefaultDeviceType = -1;
|
constexpr int kDefaultDeviceType = -1;
|
||||||
|
constexpr int kDefalutSubGraphSize = 2;
|
||||||
|
constexpr int kDefaultInputs = 1;
|
||||||
class SearchSubGraph {
|
class SearchSubGraph {
|
||||||
enum TensorType { NORMAL, CONST, INPUT };
|
enum TensorType { NORMAL, CONST, INPUT };
|
||||||
|
|
||||||
|
@ -122,6 +124,8 @@ class SearchSubGraph {
|
||||||
CostModel CalculateConv2DFusion(Model::Node *node);
|
CostModel CalculateConv2DFusion(Model::Node *node);
|
||||||
void dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group,
|
void dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group,
|
||||||
std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs);
|
std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs);
|
||||||
|
void UpdateOfflineParallelFlag();
|
||||||
|
bool CheckIsParallelSubGraph(const std::vector<Subgraph> &subgraphs);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<size_t> *output_nodes_ = nullptr;
|
std::vector<size_t> *output_nodes_ = nullptr;
|
||||||
|
|
|
@ -291,8 +291,6 @@ AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const AnfNodePt
|
||||||
concate_cnode->set_scope(conv_cnode->scope());
|
concate_cnode->set_scope(conv_cnode->scope());
|
||||||
std::vector<AnfNodePtr> outputs;
|
std::vector<AnfNodePtr> outputs;
|
||||||
GetMultipleOutputsOfAnfNode(func_graph, concate_cnode, 1, &outputs);
|
GetMultipleOutputsOfAnfNode(func_graph, concate_cnode, 1, &outputs);
|
||||||
// only support split_overlap node implementation, to split sub_graph for runtime
|
|
||||||
concate_cnode->AddAttr(mindspore::ops::kDeviceType, MakeValue(static_cast<int>(lite::DT_CPU)));
|
|
||||||
return concate_cnode;
|
return concate_cnode;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -337,8 +335,6 @@ void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const AnfNo
|
||||||
ptr_list.push_back(value_node);
|
ptr_list.push_back(value_node);
|
||||||
}
|
}
|
||||||
split_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
|
split_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list));
|
||||||
// only support split_overlap node implementation
|
|
||||||
split_cnode->AddAttr(mindspore::ops::kDeviceType, MakeValue(static_cast<int>(lite::DT_CPU)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@ -89,12 +89,12 @@ bool MultiConvSplit::CheckSplitValid() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
int64_t split_axis_value_0 = UP_DIV(split_info_.ori_split_axis_value * visited_block, total_block_count);
|
int64_t split_axis_value_0 = UP_DIV(split_info_.ori_split_axis_value * visited_block, total_block_count);
|
||||||
if (split_axis_value_0 >= split_info_.ori_split_axis_value) {
|
if (split_axis_value_0 > split_info_.ori_split_axis_value) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
int64_t split_axis_value_1 = split_info_.ori_split_axis_value - split_axis_value_0;
|
int64_t split_axis_value_1 = split_info_.ori_split_axis_value - split_axis_value_0;
|
||||||
split_axis_value_1 += (split_info_.extend_top.back() + split_info_.extend_bottom.back());
|
split_axis_value_1 += (split_info_.extend_top.back() + split_info_.extend_bottom.back());
|
||||||
return split_axis_value_1 < split_info_.ori_split_axis_value;
|
return split_axis_value_1 <= split_info_.ori_split_axis_value;
|
||||||
}
|
}
|
||||||
|
|
||||||
int MultiConvSplit::GetMultiConvNodes(const AnfNodePtr &conv_node) {
|
int MultiConvSplit::GetMultiConvNodes(const AnfNodePtr &conv_node) {
|
||||||
|
|
Loading…
Reference in New Issue