forked from mindspore-Ecosystem/mindspore
sub graph split
parent 18d79d35b6
commit 04e261ddc4
@@ -28,6 +28,7 @@ option(ENABLE_VERBOSE "" off)
option(ENABLE_SSE "if x86_64 support SSE instruction set" off)
option(ENABLE_AVX "if x86_64 support AVX instruction set" off)
option(ENABLE_MINDRT "if support mindrt" on)
+option(SUBGRAPH_SPLIT "if support sub graph split" off)

set(DIR_PREFIX mindspore-lite)
set(MS_VERSION ${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION})

@@ -57,6 +58,9 @@ else()
set(PROCESS_UNIT cpu)
endif()

+if(SUBGRAPH_SPLIT)
+  add_compile_definitions(SUBGRAPH_SPLIT)
+endif()

if(SUPPORT_NPU)
set(DDK_PATH "$ENV{HWHIAI_DDK}/ddk/ai_ddk_lib")
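The new option only injects a compile definition; every other change in this commit sits behind it, so default builds are unaffected until cmake is configured with -DSUBGRAPH_SPLIT=on. A minimal sketch of the guard pattern (the call shown is the one the scheduler hunk below adds):

    // Compiled only when cmake was configured with -DSUBGRAPH_SPLIT=on,
    // which add_compile_definitions() turns into a compiler -D flag.
    #ifdef SUBGRAPH_SPLIT
      auto search_sub_graph = SearchSubGraph(src_model_, this->graph_output_node_indexes_);
      search_sub_graph.SubGraphSplitByOutput();
    #endif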
@@ -132,6 +132,7 @@ set(LITE_SRC
${LITE_DIR}/src/common/tensor_util.cc
${LITE_DIR}/src/runtime/infer_manager.cc
${LITE_DIR}/src/lite_model.cc
+${LITE_DIR}/src/sub_graph_split.cc
${LITE_DIR}/src/tensorlist.cc
${LITE_DIR}/src/tensor.cc
${LITE_DIR}/src/dequant.cc
@@ -59,6 +59,7 @@ set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc
+${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_split.cc
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc
@@ -43,6 +43,16 @@ void LiteKernel::FreeWorkspace() {
  free(workspace_);
  workspace_ = nullptr;
}
+
+int LiteKernel::DecOutTensorRefCount() {
+  for (auto *tensor : this->out_tensors_) {
+    tensor->set_ref_count(tensor->ref_count() - 1);
+    if (0 >= tensor->ref_count()) {
+      tensor->FreeData();
+    }
+  }
+  return 0;
+}
#endif
bool LiteKernel::IsReady(const std::vector<lite::Tensor *> &scope_tensors) {
  return std::all_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *in_tensor) {

@@ -66,16 +76,6 @@ void LiteKernel::InitOutTensorInitRefCount() {
  }
}
-
-int LiteKernel::DecOutTensorRefCount() {
-  for (auto *tensor : this->out_tensors_) {
-    tensor->set_ref_count(tensor->ref_count() - 1);
-    if (0 >= tensor->ref_count()) {
-      tensor->FreeData();
-    }
-  }
-  return 0;
-}

int LiteKernel::FreeInWorkTensor() const {
  for (auto &in_tensor : this->in_tensors_) {
    MS_ASSERT(in_tensor != nullptr);
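The function body is unchanged; it only moves above the #endif so that it is compiled solely for training builds (SUPPORT_TRAIN), while runtime paths move toward Tensor::DecRefCount() (see the NPU executor and tensor.cc hunks below). A standalone sketch of the shared release pattern, with a stand-in Tensor type:

    #include <vector>
    // Stand-in type for illustration only; the real class lives in src/tensor.h.
    struct Tensor {
      int ref_count = 0;
      void FreeData() { /* release the backing buffer */ }
    };
    // Each consumer decrements its producers' output tensors; the data is
    // freed as soon as no remaining consumer needs it.
    void DecRef(const std::vector<Tensor *> &outs) {
      for (auto *t : outs) {
        if (--t->ref_count <= 0) {
          t->FreeData();
        }
      }
    }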
@@ -35,7 +35,16 @@ static constexpr int kPerTensor = 1;
static constexpr size_t kPerBatch = 3;

namespace mindspore::kernel {
-enum KERNEL_ARCH { kCPU, kGPU, kAPU, kNPU, kKernelArch_MIN = kCPU, kKernelArch_MAX = kNPU };
+enum KERNEL_ARCH {
+  kCPU,
+  kGPU,
+  kAPU,
+  kNPU,
+  kALL, /* Support GPU NPU CPU */
+  kKernelArch_MIN = kCPU,
+  kKernelArch_MAX = kALL
+};

struct KernelKey {
  KERNEL_ARCH arch;
  TypeId data_type;

@@ -161,8 +170,6 @@ class LiteKernel {

  virtual void InitOutTensorInitRefCount();

-  int DecOutTensorRefCount();
-
  virtual int FreeInWorkTensor() const;

  KernelKey desc() const { return desc_; }

@@ -171,6 +178,8 @@ class LiteKernel {

  SubGraphType subgraph_type() const { return this->subgraph_type_; }

+  const lite::InnerContext *context() const { return this->context_; }
+
  virtual std::string ToString() const;

#ifdef SUPPORT_TRAIN

@@ -179,6 +188,7 @@ class LiteKernel {
  static void AllocWorkspace(size_t size);
  static void FreeWorkspace();
  void *workspace() { return workspace_; }
+  int DecOutTensorRefCount();
#endif

 protected:
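kALL marks a kernel or subgraph that may be placed on any of CPU/GPU/NPU, and the kKernelArch_MAX sentinel moves with it. A hedged sketch of why the sentinel matters for validation code (hypothetical helper, not part of the commit):

    // Range check written against the sentinels; a hard-coded upper bound of
    // kNPU would silently reject the new kALL value.
    bool IsValidArch(mindspore::kernel::KERNEL_ARCH arch) {
      return arch >= mindspore::kernel::kKernelArch_MIN && arch <= mindspore::kernel::kKernelArch_MAX;
    }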
@@ -32,7 +32,7 @@ int LiteOpActor::CompileArrow() {
      }
    }
    if (to_input_index == -1) {
-      break;
+      continue;
    }
    auto id = out->name() + this->GetAID().Url();
    auto arrow = std::make_shared<OpArrow>(i, id, to_input_index);

@@ -41,12 +41,19 @@ int LiteOpActor::CompileArrow() {
        return RET_ERROR;
      }
      output_op_arrows_.emplace_back(std::move(arrow));
      break;
    }
  }
  return RET_OK;
}

+void LiteOpActor::AsyncOutput(OpContext<Tensor> *context) {
+  for (auto op_arrow : output_op_arrows_) {
+    auto data = context->outputData_->at(op_arrow->from_output_index_);
+    Async(op_arrow->to_op_id_, &mindspore::OpActor<Tensor>::RunOpData, data, context);
+  }
+  return;
+}
+
void LiteOpActor::SetOutputData(OpContext<Tensor> *context) {
  auto size = context->outputData_->size();
  MS_ASSERT(size == context->results_->size());
@@ -50,6 +50,7 @@ class LiteOpActor : public OpActor<lite::Tensor> {
      return;
    }
    input_op_datas_.erase(op_uuid);
+    AsyncOutput(context);
    SetOutputData(context);
  }
  void Init() {

@@ -83,6 +84,7 @@ class LiteOpActor : public OpActor<lite::Tensor> {

 private:
  void SetOutputData(OpContext<Tensor> *context);
+  void AsyncOutput(OpContext<Tensor> *context);

  kernel::LiteKernel *kernel_;
};
@@ -104,6 +104,12 @@ void LiteModel::Free() {
    tensor_buf = nullptr;
  }
  attr_tensor_bufs_.resize(0);
+
+  for (auto &node_buf : node_bufs_) {
+    free(node_buf);
+    node_buf = nullptr;
+  }
+  node_bufs_.resize(0);
}

void LiteModel::Destroy() {
@@ -192,6 +192,7 @@ class LiteModel : public Model {

 public:
  size_t buf_size_ = 0;
+  std::vector<char *> node_bufs_;

 protected:
  std::vector<char *> attr_tensor_bufs_;
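node_bufs_ gives the model ownership of primitive buffers that are fabricated at schedule time rather than read from the flatbuffer file. The split pass below builds a PartialFusion primitive in a temporary FlatBufferBuilder, so the bytes must be copied into model-owned storage before the builder is cleared; Free() releases them. A minimal sketch of that lifetime rule (names as in the sub_graph_split.cc hunk below):

    // The builder's buffer dies with fbb.Clear(); copy it out first and
    // register the copy so LiteModel::Free() can release it later.
    char *prim_buf = reinterpret_cast<char *>(malloc(fbb.GetSize()));
    memcpy(prim_buf, fbb.GetBufferPointer(), fbb.GetSize());
    model_->node_bufs_.push_back(prim_buf);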
@@ -399,8 +399,15 @@ int LiteSession::CompileGraph(Model *model) {
#endif
  InitGraphInOutTensors(model);

+  ret = PrepareKernels(model);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare kernels failed: " << ret;
+    is_running_.store(false);
+    return ret;
+  }
+
#ifdef ENABLE_MINDRT
-  if (context_->IsCpuEnabled() && !context_->IsGpuEnabled() && !context_->IsNpuEnabled() && kernels_.size() == 1) {
+  if (kernels_.size() == 1) {
    executor_ = new (std::nothrow) MindrtExecutor();
  } else {
    executor_ = new (std::nothrow) Executor();

@@ -420,16 +427,10 @@ int LiteSession::CompileGraph(Model *model) {
    is_running_.store(false);
    return ret;
  }
-  ret = PrepareKernels(model);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Prepare kernels failed: " << ret;
-    is_running_.store(false);
-    return ret;
-  }

  is_running_.store(false);
  return RET_OK;
}
}  // namespace lite

int LiteSession::PrepareKernels(Model *model) {
  std::vector<kernel::LiteKernel *> all_kernels;
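CompileGraph now prepares kernels before choosing an executor, and the later duplicate PrepareKernels call is deleted; the MindRT condition also collapses to a pure single-subgraph test. A sketch of the resulting order, under the assumption (consistent with the hunk above) that Schedule() has already filled kernels_:

    ret = PrepareKernels(model);   // kernels are ready before the executor exists
    if (ret != RET_OK) {
      return ret;
    }
    #ifdef ENABLE_MINDRT
    if (kernels_.size() == 1) {    // one subgraph -> actor-based executor
      executor_ = new (std::nothrow) MindrtExecutor();
    } else {
      executor_ = new (std::nothrow) Executor();
    }
    #endif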
@@ -102,10 +102,7 @@ int NPUExecutor::Run(const std::vector<Tensor *> &in_tensors, const std::vector<
      memcpy(npu_input_tensors_[i]->GetBuffer(), data, in_tensors[index]->Size());
      inputs_visited[index] = true;
-      in_tensors[index]->set_ref_count(in_tensors[index]->ref_count() - 1);
-      if (in_tensors[index]->ref_count() <= 0) {
-        in_tensors[index]->FreeData();
-      }
+      in_tensors[index]->DecRefCount();
      break;
    }
  }
@@ -38,6 +38,7 @@ class SubGraphNpuKernel : public SubGraphKernel {
                    const lite::InnerContext *ctx = nullptr, lite::NPUManager *npu_manager = nullptr)
      : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx), npu_manager_(npu_manager) {
    subgraph_type_ = kNpuSubGraph;
+    desc_.arch = kernel::KERNEL_ARCH::kNPU;
  }

  ~SubGraphNpuKernel() override;
@@ -70,7 +70,7 @@ class DefaultAllocator : public Allocator {
  std::multimap<size_t, MemBuf *> freeList_;
  // 6 is empirical value
  int shiftFactor_ = 6;
-  bool lockFlag_ = false;
+  bool lockFlag_ = true;
};

constexpr int64_t MAX_MALLOC_SIZE = static_cast<size_t>(2000) * 1024 * 1024;
@@ -34,6 +34,7 @@ class OpenCLSubGraph : public SubGraphKernel {
      : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx) {
    ocl_runtime_ = ocl_runtime_wrap_.GetInstance();
    subgraph_type_ = kGpuSubGraph;
+    desc_.arch = kernel::KERNEL_ARCH::kGPU;
    this->name_ = "GpuSubGraph";
    nodes_set_.insert(nodes.begin(), nodes.end());
    all_kernels_infer_done_ = std::all_of(nodes_.begin(), nodes_.end(), [](const kernel::LiteKernel *kernel) {
@@ -30,6 +30,7 @@
#include "src/common/version_manager.h"
#include "src/common/prim_util.h"
#include "src/runtime/infer_manager.h"
+#include "src/sub_graph_split.h"
#include "src/dequant.h"
#include "nnacl/matmul_parameter.h"
#if GPU_OPENCL

@@ -71,6 +72,12 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
  }

  this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_);

+#ifdef SUBGRAPH_SPLIT
+  auto search_sub_graph = SearchSubGraph(src_model_, this->graph_output_node_indexes_);
+  search_sub_graph.SubGraphSplitByOutput();
+#endif
+
  bool infer_shape_interrupt = false;
  auto ret = InferSubGraphShape(kMainSubGraphIndex, &infer_shape_interrupt);
  if (ret != RET_OK) {

@@ -89,7 +96,11 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
    MS_LOG(ERROR) << "Schedule run pass failed.";
    return ret;
  }
-  ret = ConstructSubGraphs(dst_kernels);
+
+  auto src_kernel = *dst_kernels;
+  dst_kernels->clear();
+  std::map<const kernel::LiteKernel *, bool> is_kernel_finish;
+  ret = ConstructSubGraphs(src_kernel, dst_kernels, &is_kernel_finish);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ConstructSubGraphs failed.";
    return ret;

@@ -465,6 +476,14 @@ kernel::LiteKernel *Scheduler::SchedulePartialToKernel(const lite::Model::Node *
    MS_LOG(ERROR) << "Schedule partial failed, name: " << src_node->name_;
    return nullptr;
  }
+
+  FindAllInoutKernels(sub_kernels);
+  ret = RunPass(&sub_kernels);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "SchedulePartialToKernel run pass failed.";
+    return nullptr;
+  }
+
  auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(sub_kernels.front());
  auto subgraph = CreateSubGraphKernel(sub_kernels, &in_tensors, &out_tensors, cur_sub_graph_type);
  subgraph->set_name("subgraph_" + src_node->name_);

@@ -594,35 +613,33 @@ std::vector<kernel::LiteKernel *> Scheduler::FindAllSubGraphKernels(
  return sub_kernels;
}

-int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) {
-  auto old_kernels = *kernels;
-  kernels->clear();
-  std::map<const kernel::LiteKernel *, bool> is_kernel_finish;
-  for (auto kernel : old_kernels) {
-    is_kernel_finish[kernel] = false;
+int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> src_kernel,
+                                  std::vector<kernel::LiteKernel *> *dst_kernel,
+                                  std::map<const kernel::LiteKernel *, bool> *is_kernel_finish) {
+  for (auto kernel : src_kernel) {
+    (*is_kernel_finish)[kernel] = false;
  }

  while (true) {
-    auto head_kernel_iter = std::find_if(old_kernels.begin(), old_kernels.end(), [&](const kernel::LiteKernel *kernel) {
+    auto head_kernel_iter = std::find_if(src_kernel.begin(), src_kernel.end(), [&](const kernel::LiteKernel *kernel) {
      auto kernel_inputs = kernel->in_kernels();
-      if (is_kernel_finish[kernel]) {
+      if ((*is_kernel_finish)[kernel]) {
        return false;
      }
      // when merge is removed, this if is removed automatically
      if (kernel->Type() == schema::PrimitiveType_Merge) {
-        return MergeOpIsReady(kernel, is_kernel_finish);
+        return MergeOpIsReady(kernel, (*is_kernel_finish));
      } else {
        return std::all_of(kernel_inputs.begin(), kernel_inputs.end(),
-                           [&](kernel::LiteKernel *kernel) { return is_kernel_finish[kernel]; });
+                           [&](kernel::LiteKernel *kernel) { return (*is_kernel_finish)[kernel]; });
      }
    });
-    if (head_kernel_iter == old_kernels.end()) {
+    if (head_kernel_iter == src_kernel.end()) {
      break;
    }
    auto head_kernel = *head_kernel_iter;
    if (head_kernel->subgraph_type() != kernel::kNotSubGraph) {
-      is_kernel_finish[head_kernel] = true;
-      kernels->emplace_back(head_kernel);
+      (*is_kernel_finish)[head_kernel] = true;
+      dst_kernel->push_back(head_kernel);
      continue;
    }
    if (head_kernel->desc().arch == mindspore::kernel::kAPU) {

@@ -630,15 +647,15 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) {
      return RET_NOT_SUPPORT;
    }
    auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel);
-    auto sub_kernels = FindAllSubGraphKernels(head_kernel, &is_kernel_finish);
+    auto sub_kernels = FindAllSubGraphKernels(head_kernel, is_kernel_finish);
    auto subgraph = CreateSubGraphKernel(sub_kernels, nullptr, nullptr, cur_sub_graph_type);
    if (subgraph == nullptr) {
      MS_LOG(ERROR) << "Create SubGraphKernel failed";
      return RET_ERROR;
    }
-    kernels->emplace_back(subgraph);
+    dst_kernel->emplace_back(subgraph);
  }
-  for (auto *subgraph : *kernels) {
+  for (auto *subgraph : *dst_kernel) {
    auto ret = subgraph->Init();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Init SubGraph failed: " << ret;

@@ -832,6 +849,7 @@ int Scheduler::RunPass(std::vector<kernel::LiteKernel *> *dst_kernels) {
  npu_pass_manager_->AddPass(fusion_pass);

  ret = npu_pass_manager_->Run();
+  npu_pass_manager_->Clear();
#endif
  return ret;
}
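ConstructSubGraphs now threads the finished-kernel map through its signature so callers (Schedule here, and potentially SchedulePartialToKernel) can share one bookkeeping structure instead of each building its own. The core test it repeats is: a kernel may head a new subgraph only once every producing kernel is already assigned. A reduced, self-contained sketch of that readiness predicate (stand-in Kernel type, illustration only):

    #include <algorithm>
    #include <map>
    #include <vector>
    struct Kernel { std::vector<Kernel *> inputs; };
    // A kernel is schedulable once every input kernel is marked finished.
    bool IsReady(const Kernel *k, std::map<const Kernel *, bool> *finished) {
      return std::all_of(k->inputs.begin(), k->inputs.end(),
                         [&](Kernel *in) { return (*finished)[in]; });
    }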
@@ -74,7 +74,8 @@ class Scheduler {
  static void FindAllInoutKernels(const std::vector<kernel::LiteKernel *> &kernels);

  // vector<LiteKernel/SubGraphKernel> --> vector<SubGraphKernel>
-  int ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels);
+  int ConstructSubGraphs(std::vector<kernel::LiteKernel *> src_kernel, std::vector<kernel::LiteKernel *> *dst_kernel,
+                         std::map<const kernel::LiteKernel *, bool> *sinked_kernel_map);

  // create subgraph_kernel from a vector of kernel
  kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels,
@@ -128,6 +128,7 @@ class CpuSubGraph : public SubGraphKernel {
              std::vector<LiteKernel *> nodes, const lite::InnerContext *ctx)
      : SubGraphKernel(inputs, outputs, std::move(in_kernels), std::move(out_kernels), std::move(nodes), ctx) {
    subgraph_type_ = kCpuFP32SubGraph;
+    desc_.arch = kernel::KERNEL_ARCH::kCPU;
  }

  ~CpuSubGraph() override { delete this->executor_; }
@@ -0,0 +1,269 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/sub_graph_split.h"
#include <vector>
#include <utility>
#include "src/tensor.h"
#include "schema/inner/ops_generated.h"
#include "schema/inner/model_generated.h"

namespace mindspore::lite {
#ifdef SUBGRAPH_SPLIT
const schema::Primitive *SearchSubGraph::CreatePartialPrimitive(int64_t subgraph_index) {
  flatbuffers::FlatBufferBuilder fbb(1024);
  auto val_offset = schema::CreatePartialFusion(fbb, subgraph_index);
  auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_PartialFusion, val_offset.o);
  fbb.Finish(prim_offset);
  auto tmp_buf = fbb.GetBufferPointer();
  auto prim_buf = reinterpret_cast<char *>(malloc(fbb.GetSize()));
  memcpy(prim_buf, tmp_buf, fbb.GetSize());

  auto primitive = flatbuffers::GetRoot<schema::Primitive>(prim_buf);
  fbb.Clear();

  model_->node_bufs_.push_back(prim_buf);
  return std::move(primitive);
}

void SearchSubGraph::ConvertSubGraphToModel() {
  Model::SubGraph *main_graphs = model_->sub_graphs_.front();

  for (Subgraph &subgraph : sub_graphs_) {
    if (subgraph.nodes_.empty()) {
      continue;
    }
    mindspore::kernel::KERNEL_ARCH device = subgraph.device_;

    int new_sub_index = model_->sub_graphs_.size();
    int partial_index = model_->all_nodes_.size();

    Model::SubGraph *new_sub_graph = new (std::nothrow) Model::SubGraph();
    if (new_sub_graph == nullptr) {
      MS_LOG(ERROR) << "New sub graph failed!";
      return;
    }
    new_sub_graph->name_ = "Subgraph-split-" + std::to_string(new_sub_index);

    Model::Node *new_partial_node = new (std::nothrow) Model::Node();
    if (new_partial_node == nullptr) {
      MS_LOG(ERROR) << "New partial node failed!";
      return;
    }
    new_partial_node->name_ = "Partial-subgraph-split-" + std::to_string(new_sub_index);
    new_partial_node->node_type_ = mindspore::lite::NodeType_ValueNode;
    new_partial_node->primitive_ = CreatePartialPrimitive(new_sub_index);

    while (!subgraph.nodes_.empty()) {
      uint32_t node_index = subgraph.nodes_.front();
      new_sub_graph->node_indices_.push_back(node_index);
      VectorErase(&main_graphs->node_indices_, node_index);
      VectorErase(&subgraph.nodes_, node_index);
      model_->all_nodes_[node_index]->device_type_ = device;
    }

    for (uint32_t head_index : subgraph.heads_) {
      Model::Node *head_node = model_->all_nodes_[head_index];
      std::vector<uint32_t> inputs = head_node->input_indices_;
      for (auto input : inputs) {
        if (tensors_[input].type_ == CONST) {
          continue;
        }
        if (std::find(new_sub_graph->input_indices_.begin(), new_sub_graph->input_indices_.end(), input) !=
            new_sub_graph->input_indices_.end()) {
          continue;
        }
        new_sub_graph->input_indices_.insert(new_sub_graph->input_indices_.end(), input);
        new_partial_node->input_indices_.insert(new_partial_node->input_indices_.end(), input);
      }
    }

    for (uint32_t end_index : subgraph.ends_) {
      Model::Node *end_node = model_->all_nodes_[end_index];
      std::vector<uint32_t> outputs = end_node->output_indices_;
      new_sub_graph->output_indices_.insert(new_sub_graph->output_indices_.end(), outputs.begin(), outputs.end());
      new_partial_node->output_indices_.insert(new_partial_node->output_indices_.end(), outputs.begin(), outputs.end());
    }

    main_graphs->node_indices_.push_back(partial_index);
    model_->all_nodes_.push_back(std::move(new_partial_node));
    model_->sub_graphs_.push_back(std::move(new_sub_graph));
  }
  return;
}

bool SearchSubGraph::IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes) {
  std::vector<uint32_t> output_indexes = node_list_[node_index]->output_indices_;
  std::vector<uint32_t> output_nodes;
  for (uint32_t out_t : output_indexes) {
    std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_;
    output_nodes.insert(output_nodes.end(), cur_nodes.begin(), cur_nodes.end());
  }
  for (uint32_t out_n : output_nodes) {
    if (find(ready_nodes.begin(), ready_nodes.end(), out_n) == ready_nodes.end()) {
      return true;
    }
  }
  return false;
}

void SearchSubGraph::InsertNode(uint32_t index, Subgraph *subgraph) {
  if (subgraph->search_terminate_) {
    return;
  }

  Model::Node *node = node_list_[index];
  if (node == nullptr) {
    return;
  }

  std::vector<uint32_t> input = node->input_indices_;
  /* remove const node */
  for (int i = input.size() - 1; i >= 0; i--) {
    if (tensors_[input[i]].type_ == CONST) {
      input.erase(input.begin() + i);
    }
  }

  /* all node_input is graph_input */
  for (size_t i = 0; i < input.size(); i++) {
    if (tensors_[input[i]].type_ != INPUT) {
      break;
    }
    subgraph->heads_.clear();
    subgraph->ends_.clear();
    subgraph->nodes_.clear();
    subgraph->search_terminate_ = true;
    return;
  }

  /* split in graph */
  if (IsNodeSubGraphHead(index, subgraph->nodes_)) {
    if (subgraph->nodes_.empty()) {
      subgraph->search_terminate_ = true;
      return;
    }
    subgraph->heads_.push_back(subgraph->nodes_.front());
    return;
  }

  if (find(output_nodes_.begin(), output_nodes_.end(), index) != output_nodes_.end()) {
    subgraph->ends_.push_back(index);
  }

  /* node insert in current subgraph */
  subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
  node_list_[index] = nullptr;

  /* search for next node */
  for (uint32_t in : input) {
    auto next_nodes = tensors_[in].out_nodes_;
    for (uint32_t next_node : next_nodes) {
      InsertNode(next_node, subgraph);
    }
  }
  return;
}

void SearchSubGraph::InitSearchSubGraph() {
  for (uint32_t out : output_nodes_) {
    Subgraph subgraph;

    InsertNode(out, &subgraph);

    sub_graphs_.push_back(std::move(subgraph));
  }
  return;
}

void SearchSubGraph::InitSearchTensor() {
  tensors_.resize(model_->all_tensors_.size());

  /* Set Tensor Type */
  for (size_t i = 0; i < tensors_.size(); i++) {
    tensors_[i].type_ = NORMAL;
    mindspore::schema::Tensor *src_tensor = model_->all_tensors_[i];
    auto category = TensorCategory(src_tensor);
    if (category == mindspore::lite::Tensor::Category::CONST_TENSOR ||
        category == mindspore::lite::Tensor::Category::CONST_SCALAR) {
      tensors_[i].type_ = CONST;
    }
  }
  std::vector<uint32_t> graph_input = model_->sub_graphs_[0]->input_indices_;
  for (auto in : graph_input) {
    tensors_[in].type_ = INPUT;
  }

  /* Set Tensor In and out Node */
  for (size_t index = 0; index < model_->all_nodes_.size(); index++) {
    Model::Node *node = model_->all_nodes_[index];
    std::vector<uint32_t> input = node->input_indices_;
    for (uint32_t in : input) {
      tensors_[in].in_nodes_.push_back(index);
    }
    std::vector<uint32_t> output = node->output_indices_;
    for (uint32_t out : output) {
      tensors_[out].out_nodes_.push_back(index);
    }
  }
  return;
}

void SearchSubGraph::InitSubgraphDevice() {
  sub_graphs_[0].device_ = kernel::KERNEL_ARCH::kCPU;
  sub_graphs_[1].device_ = kernel::KERNEL_ARCH::kALL;
}

void SearchSubGraph::InitMainGraphDevice() {
  kernel::KERNEL_ARCH main_device = kernel::KERNEL_ARCH::kALL;
  Model::SubGraph *main_graph = model_->sub_graphs_.front();
  for (uint32_t node_index : main_graph->node_indices_) {
    Model::Node *node = model_->all_nodes_[node_index];
    node->device_type_ = main_device;
  }
}

void SearchSubGraph::SubgraphFusion() {
  Subgraph new_npu_sub;
  Subgraph &npu_sub1 = sub_graphs_[1];
  Subgraph &npu_sub2 = sub_graphs_[2];
  new_npu_sub.nodes_.insert(new_npu_sub.nodes_.end(), npu_sub1.nodes_.begin(), npu_sub1.nodes_.end());
  new_npu_sub.nodes_.insert(new_npu_sub.nodes_.end(), npu_sub2.nodes_.begin(), npu_sub2.nodes_.end());
  new_npu_sub.heads_.insert(new_npu_sub.heads_.end(), npu_sub1.heads_.begin(), npu_sub1.heads_.end());
  new_npu_sub.heads_.insert(new_npu_sub.heads_.end(), npu_sub2.heads_.begin(), npu_sub2.heads_.end());
  new_npu_sub.ends_.insert(new_npu_sub.ends_.end(), npu_sub1.ends_.begin(), npu_sub1.ends_.end());
  new_npu_sub.ends_.insert(new_npu_sub.ends_.end(), npu_sub2.ends_.begin(), npu_sub2.ends_.end());
  sub_graphs_.erase(sub_graphs_.begin() + 2);
  sub_graphs_.erase(sub_graphs_.begin() + 1);
  sub_graphs_.insert(sub_graphs_.end(), std::move(new_npu_sub));
  return;
}

void SearchSubGraph::SubGraphSplitByOutput() {
  InitSearchTensor();

  InitSearchSubGraph();

  SubgraphFusion();

  InitSubgraphDevice();

  ConvertSubGraphToModel();

  InitMainGraphDevice();
}
#endif
}  // namespace mindspore::lite
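As wired into Scheduler::Schedule above, the splitter rewrites the flat model before any kernel is selected. Note that SubgraphFusion and InitSubgraphDevice index sub_graphs_[1] and sub_graphs_[2] directly, so this first version assumes the output-rooted search yields exactly three candidates (one kept on CPU, two fused and marked kALL). A hedged usage sketch, assuming a loaded Model *model:

    // Tags tensors, grows subgraphs backwards from each graph output, fuses
    // the two offload candidates, assigns devices, and replaces the split-out
    // nodes with PartialFusion calls in the main graph.
    #ifdef SUBGRAPH_SPLIT
      auto search_sub_graph = mindspore::lite::SearchSubGraph(model, GetGraphOutputNodes(model));
      search_sub_graph.SubGraphSplitByOutput();
    #endif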
@@ -0,0 +1,78 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_SUB_GRAPH_SPLIT_H_
#define MINDSPORE_LITE_SRC_SUB_GRAPH_SPLIT_H_

#include <stack>
#include <vector>
#include "include/model.h"
#include "src/lite_kernel.h"
#include "src/lite_model.h"

namespace mindspore::lite {
#ifdef SUBGRAPH_SPLIT
class SearchSubGraph {
  enum TensorType { NORMAL, CONST, INPUT };

  struct Tensor {
    std::vector<uint32_t> in_nodes_; /* nodes that use this tensor as input */
    std::vector<uint32_t> out_nodes_;
    TensorType type_;
  };

  struct Subgraph {
    std::vector<uint32_t> nodes_;
    std::vector<uint32_t> heads_;
    std::vector<uint32_t> ends_;
    bool search_terminate_ = false;
    mindspore::kernel::KERNEL_ARCH device_;
  };

 public:
  SearchSubGraph(Model *model, std::vector<size_t> output_nodes) {
    output_nodes_.insert(output_nodes_.end(), output_nodes.begin(), output_nodes.end());
    node_list_ = model->all_nodes_;
    model_ = reinterpret_cast<LiteModel *>(model);
  }
  ~SearchSubGraph() = default;

 public:
  void SubGraphSplitByOutput();

 private:
  void InitSearchTensor();
  void InitSearchSubGraph();
  void ConvertSubGraphToModel();
  void InsertNode(uint32_t index, Subgraph *subgraph);
  bool IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes);
  const schema::Primitive *CreatePartialPrimitive(int64_t subgraph_index);
  void InitSubgraphDevice();
  void SubgraphFusion();
  void InitMainGraphDevice();

 private:
  LiteModel *model_ = nullptr;
  std::vector<Tensor> tensors_;
  std::vector<Subgraph> sub_graphs_;
  std::vector<size_t> output_nodes_;
  std::vector<Model::Node *> node_list_;
};

#endif
}  // namespace mindspore::lite

#endif  // MINDSPORE_LITE_SRC_SUB_GRAPH_SPLIT_H_
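The splitter's Tensor records which nodes consume (in_nodes_) and produce (out_nodes_) each model tensor, giving a cheap adjacency index for walking the graph backwards from its outputs. A reduced, self-contained sketch of that indexing step (stand-in types, mirrors SearchSubGraph::InitSearchTensor):

    #include <cstdint>
    #include <vector>
    // Stand-in for the Tensor bookkeeping above, illustration only.
    struct TensorInfo {
      std::vector<uint32_t> in_nodes_;   // nodes that read this tensor
      std::vector<uint32_t> out_nodes_;  // nodes that write this tensor
    };
    // Register every node on each tensor it touches, so the backward walk
    // can look up a tensor's neighbours in O(1).
    void BuildIndex(const std::vector<std::vector<uint32_t>> &node_inputs,
                    const std::vector<std::vector<uint32_t>> &node_outputs,
                    std::vector<TensorInfo> *tensors) {
      for (size_t n = 0; n < node_inputs.size(); ++n) {
        for (uint32_t t : node_inputs[n]) (*tensors)[t].in_nodes_.push_back(n);
        for (uint32_t t : node_outputs[n]) (*tensors)[t].out_nodes_.push_back(n);
      }
    }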
@@ -352,8 +352,8 @@ void Tensor::DecRefCount() {
  if (this->IsConst() || this->IsGraphInput()) {
    return;
  }
-  this->ref_count_--;
-  if (this->ref_count_ <= 0) {
+  bool free_data = --ref_count_ <= 0;
+  if (free_data) {
    FreeData();
    this->ref_count_ = 0;
  }
@@ -22,6 +22,7 @@
#include <string>
#include <numeric>
#include <functional>
+#include <atomic>
#include "include/ms_tensor.h"
#include "src/runtime/allocator.h"

@@ -205,7 +206,7 @@ class Tensor : public mindspore::tensor::MSTensor {
  std::vector<int> shape_;
  schema::Format format_;
  Category category_;
-  size_t ref_count_ = 0;
+  std::atomic_int ref_count_ = 0;
  size_t init_ref_count_ = 0;
  std::vector<QuantArg> quant_params_;
  std::vector<float> quant_clusters_;
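With ref_count_ now std::atomic_int, the decrement-then-test in DecRefCount becomes a single atomic read-modify-write, so two kernels (or MindRT actors) releasing the same tensor concurrently cannot both read a stale count and double-free. A minimal standalone sketch of the idiom:

    #include <atomic>
    #include <cstdio>
    int main() {
      std::atomic_int ref{2};
      // operator-- is an atomic fetch_sub; exactly one releaser observes 0.
      bool free_data = --ref <= 0;   // first release: 2 -> 1, false
      bool free_data2 = --ref <= 0;  // second release: 1 -> 0, true
      std::printf("%d %d\n", free_data, free_data2);
      return 0;
    }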
@@ -144,6 +144,7 @@ set(TEST_LITE_SRC
${LITE_DIR}/src/dequant.cc
${LITE_DIR}/src/huffman_decode.cc
${LITE_DIR}/src/sub_graph_kernel.cc
+${LITE_DIR}/src/sub_graph_split.cc
${LITE_DIR}/src/lite_model.cc
${LITE_DIR}/src/scheduler.cc
${LITE_DIR}/src/common/graph_util.cc
@@ -109,6 +109,7 @@ set(LITE_SRC
${SRC_DIR}/lite_kernel.cc
${SRC_DIR}/scheduler.cc
${SRC_DIR}/sub_graph_kernel.cc
+${SRC_DIR}/sub_graph_split.cc
${SRC_DIR}/lite_session.cc
${SRC_DIR}/executor.cc
${SRC_DIR}/lite_model.cc