!49445 [MS][Lite][Task] optimize new infer framework

Merge pull request !49445 from 刘力力/feature_new_infer_merge
This commit is contained in:
i-robot 2023-02-27 07:42:31 +00:00 committed by Gitee
commit b127cd5015
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
30 changed files with 799 additions and 178 deletions

View File

@ -0,0 +1,63 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "extendrt/execution_flow.h"
#include "litert/lite_kernel.h"
#include "litert/kernel_exec_util.h"
#include "litert/sub_graph_kernel.h"
namespace mindspore::infer {
ExecutionFlow::~ExecutionFlow() {
for (auto tensor : inputs_) {
if (tensor != nullptr) {
delete tensor;
}
}
for (auto tensor : outputs_) {
if (tensor != nullptr) {
delete tensor;
}
}
for (auto kernel : kernels_) {
if (kernel != nullptr) {
delete kernel;
}
}
}
abstract::Kernel *ExecutionFlow::ConstructFusionKernel() {
auto lite_kernel = new (std::nothrow) mindspore::kernel::LiteKernel(nullptr, inputs_, outputs_, context_);
if (lite_kernel == nullptr) {
MS_LOG(ERROR) << "ExecutionFlow::ConstructFusionKernel create lite kernel failed, may be memory is not enough";
return nullptr;
}
std::vector<KernelExec *> input_kernels = mindspore::kernel::KernelExecUtil::SubgraphInputNodes(kernels_);
std::vector<KernelExec *> output_kernels = mindspore::kernel::KernelExecUtil::SubgraphOutputNodes(kernels_);
mindspore::kernel::SubGraphKernel *sub_graph_kernel =
new (std::nothrow) CpuFp32SubGraph(input_kernels, output_kernels, kernels_, lite_kernel);
if (sub_graph_kernel == nullptr) {
MS_LOG(ERROR) << "ExecutionFlow::ConstructFusionKernel create sub graph kernel failed, possibly out of memory";
delete lite_kernel;
return nullptr;
}
sub_graph_kernel->set_context(context_);
return sub_graph_kernel;
}
} // namespace mindspore::infer
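ConstructFusionKernel collapses the flow's kernel list into one sub-graph kernel that the runtime can schedule as a single unit. Below is a minimal self-contained sketch of that idea; Kernel, FusedKernel, and NoopKernel are illustrative stand-ins, not the MindSpore classes.

#include <memory>
#include <vector>

// Stand-in for abstract::Kernel.
struct Kernel {
  virtual ~Kernel() = default;
  virtual int Run() = 0;  // 0 on success
};

// Stand-in for a fused sub-graph kernel: owns its children and runs them in order.
struct FusedKernel : public Kernel {
  explicit FusedKernel(std::vector<std::unique_ptr<Kernel>> nodes) : nodes_(std::move(nodes)) {}
  int Run() override {
    for (auto &node : nodes_) {
      int ret = node->Run();
      if (ret != 0) return ret;  // stop at the first failing child
    }
    return 0;
  }
 private:
  std::vector<std::unique_ptr<Kernel>> nodes_;
};

struct NoopKernel : public Kernel {
  int Run() override { return 0; }
};

int main() {
  std::vector<std::unique_ptr<Kernel>> nodes;
  nodes.push_back(std::make_unique<NoopKernel>());
  nodes.push_back(std::make_unique<NoopKernel>());
  FusedKernel fused(std::move(nodes));
  return fused.Run();  // 0: both children executed
}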

View File

@ -18,6 +18,7 @@
#define MINDSPORE_LITE_SRC_EXTENDRT_EXECUTION_FLOW_H_
#include <vector>
#include <memory>
#include "infer/execution_flow.h"
@ -25,23 +26,7 @@ namespace mindspore::infer {
class ExecutionFlow : public abstract::ExecutionFlow {
public:
ExecutionFlow() = default;
virtual ~ExecutionFlow() {
for (auto tensor : inputs_) {
if (tensor != nullptr) {
delete tensor;
}
}
for (auto tensor : outputs_) {
if (tensor != nullptr) {
delete tensor;
}
}
for (auto kernel : kernels_) {
if (kernel != nullptr) {
delete kernel;
}
}
}
virtual ~ExecutionFlow();
std::vector<abstract::Kernel *> GetKernels() override { return kernels_; }
@ -55,9 +40,9 @@ class ExecutionFlow : public abstract::ExecutionFlow {
void SetOutputs(const std::vector<abstract::Tensor *> &outputs) override { outputs_ = outputs; }
abstract::Context *GetContext() override { return context_; }
std::shared_ptr<abstract::Context> GetContext() override { return context_; }
void SetContext(abstract::Context *context) override { context_ = context; }
void SetContext(std::shared_ptr<abstract::Context> context) override { context_ = context; }
const abstract::KernelCallBack &GetKernelBeforeCallBack() override { return before_; }
@ -67,11 +52,13 @@ class ExecutionFlow : public abstract::ExecutionFlow {
void SetKernelAfterCallBack(const abstract::KernelCallBack &callback) override { after_ = callback; }
abstract::Kernel *ConstructFusionKernel() override;
private:
std::vector<abstract::Kernel *> kernels_;
std::vector<abstract::Tensor *> inputs_;
std::vector<abstract::Tensor *> outputs_;
abstract::Context *context_;
std::shared_ptr<abstract::Context> context_;
abstract::KernelCallBack before_;
abstract::KernelCallBack after_;
};

View File

@ -0,0 +1,64 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "extendrt/execution_plan.h"
#include "litert/lite_kernel.h"
#include "litert/kernel_exec_util.h"
#include "litert/sub_graph_kernel.h"
namespace mindspore::infer {
ExecutionPlan::~ExecutionPlan() {
if (input_isolate_map_ != nullptr) {
delete input_isolate_map_;
input_isolate_map_ = nullptr;
}
if (output_isolate_map_) {
delete output_isolate_map_;
output_isolate_map_ = nullptr;
}
for (auto tensor : inputs_) {
if (tensor != nullptr) {
delete tensor;
}
}
for (auto tensor : outputs_) {
if (tensor != nullptr) {
delete tensor;
}
}
}
std::vector<abstract::Kernel *> ExecutionPlan::ToKernelList() {
std::vector<abstract::Kernel *> kernels;
for (auto flow : execution_flows_) {
if (flow == nullptr) {
MS_LOG(ERROR) << "ExecutionPlan::ToKernelList get nullptr execution flow.";
return std::vector<abstract::Kernel *>{};
}
auto kernel = flow->ConstructFusionKernel();
if (kernel == nullptr) {
MS_LOG(ERROR) << "ExecutionPlan::ToKernelList construct execution flow to Sub Graph Kernel failed.";
return std::vector<abstract::Kernel *>{};
}
kernels.emplace_back(kernel);
}
return kernels;
}
} // namespace mindspore::infer
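ToKernelList flattens the plan into one fused kernel per execution flow, aborting with an empty vector if any flow is null or fails to fuse. A self-contained sketch of that flatten-or-fail contract follows; Flow and Kernel are stand-ins, not the MindSpore types.

#include <memory>
#include <vector>

struct Kernel { virtual ~Kernel() = default; };

struct Flow {
  virtual ~Flow() = default;
  virtual Kernel *Fuse() = 0;  // returns nullptr on failure
};

// An empty result signals failure, matching ExecutionPlan::ToKernelList above.
std::vector<Kernel *> ToKernelList(const std::vector<std::shared_ptr<Flow>> &flows) {
  std::vector<Kernel *> kernels;
  kernels.reserve(flows.size());
  for (const auto &flow : flows) {
    if (flow == nullptr) return {};    // null flow: abort
    Kernel *kernel = flow->Fuse();
    if (kernel == nullptr) return {};  // fusion failed: abort
    kernels.push_back(kernel);
  }
  return kernels;
}

int main() {
  std::vector<std::shared_ptr<Flow>> flows;  // an empty plan yields an empty list
  return static_cast<int>(ToKernelList(flows).size());
}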

View File

@ -27,27 +27,7 @@ namespace mindspore::infer {
class ExecutionPlan : public abstract::ExecutionPlan {
public:
ExecutionPlan() = default;
virtual ~ExecutionPlan() {
if (input_isolate_map_ != nullptr) {
delete input_isolate_map_;
input_isolate_map_ = nullptr;
}
if (output_isolate_map_) {
delete output_isolate_map_;
output_isolate_map_ = nullptr;
}
for (auto tensor : inputs_) {
if (tensor != nullptr) {
delete tensor;
}
}
for (auto tensor : outputs_) {
if (tensor != nullptr) {
delete tensor;
}
}
}
virtual ~ExecutionPlan();
std::vector<std::shared_ptr<abstract::ExecutionFlow>> GetExecutionFLows() override { return execution_flows_; }
@ -55,7 +35,7 @@ class ExecutionPlan : public abstract::ExecutionPlan {
execution_flows_ = execution_flows;
}
void AddExecutionFlow(std::shared_ptr<ExecutionFlow> execution_flow) override {
void AddExecutionFlow(std::shared_ptr<abstract::ExecutionFlow> execution_flow) override {
execution_flows_.emplace_back(execution_flow);
}
@ -71,25 +51,42 @@ class ExecutionPlan : public abstract::ExecutionPlan {
void SetOutputs(const std::vector<abstract::Tensor *> &outputs) override { outputs_ = outputs; }
void SetInputsMap(std::unordered_map<Tensor *, Tensor *> *input_isolate_map) {
std::shared_ptr<abstract::Context> GetContext() override { return context_; }
void SetContext(std::shared_ptr<abstract::Context> context) override { context_ = context; }
const abstract::KernelCallBack &GetKernelBeforeCallBack() override { return before_; }
void SetKernelBeforeCallBack(const abstract::KernelCallBack &callback) override { before_ = callback; }
const abstract::KernelCallBack &GetKernelAfterCallBack() override { return after_; }
void SetKernelAfterCallBack(const abstract::KernelCallBack &callback) override { after_ = callback; }
void SetInputsMap(std::unordered_map<abstract::Tensor *, abstract::Tensor *> *input_isolate_map) {
input_isolate_map_ = input_isolate_map;
}
std::unordered_map<Tensor *, Tensor *> *GetInputMap() { return input_isolate_map_; }
std::unordered_map<abstract::Tensor *, abstract::Tensor *> *GetInputMap() { return input_isolate_map_; }
void SetOutputsMap(std::unordered_map<Tensor *, Tensor *> *output_isolate_map) {
void SetOutputsMap(std::unordered_map<abstract::Tensor *, abstract::Tensor *> *output_isolate_map) {
output_isolate_map_ = output_isolate_map;
}
std::unordered_map<Tensor *, Tensor *> *GetOutputMap() { return output_isolate_map_; }
std::unordered_map<abstract::Tensor *, abstract::Tensor *> *GetOutputMap() { return output_isolate_map_; }
std::vector<abstract::Kernel *> ToKernelList() override;
private:
std::vector<std::shared_ptr<abstract::ExecutionFlow>> execution_flows_;
FuncGraphPtr func_graph_;
std::vector<abstract::Tensor *> inputs_;
std::vector<abstract::Tensor *> outputs_;
std::unordered_map<Tensor *, Tensor *> *input_isolate_map_ = nullptr;
std::unordered_map<Tensor *, Tensor *> *output_isolate_map_ = nullptr;
std::shared_ptr<abstract::Context> context_;
abstract::KernelCallBack before_;
abstract::KernelCallBack after_;
std::unordered_map<abstract::Tensor *, abstract::Tensor *> *input_isolate_map_ = nullptr;
std::unordered_map<abstract::Tensor *, abstract::Tensor *> *output_isolate_map_ = nullptr;
};
} // namespace mindspore::infer

View File

@ -18,7 +18,9 @@
#include <algorithm>
#include "extendrt/graph_compiler/default_graph_compiler.h"
#include "extendrt/graph_compiler/factory.h"
#include "extendrt/mock/lite_runtime/converters.h"
#include "backend/graph_compiler/graph_partition.h"
#include "backend/graph_compiler/segment_runner.h"
#include "common/log.h"
@ -32,6 +34,8 @@ static constexpr auto ms_infer_backend_name = "mindspore_lite_backend";
std::shared_ptr<abstract::ExecutionPlan> DefaultGraphCompiler::Compile(FuncGraphPtr graph) {
MS_LOG(INFO) << "DefaultGraphCompiler::Compile";
inner_context_ = ContextUtils::Convert(context_.get());
MS_LOG(DEBUG) << "DefaultGraphCompiler::Partition Partition FunctionGraph Begin";
auto graph_segments = Partition(graph);
if (graph_segments.empty()) {
@ -102,6 +106,7 @@ std::shared_ptr<abstract::ExecutionPlan> DefaultGraphCompiler::Schedule(
return nullptr;
}
execution_plan->SetOutputs(graph_output_tensor);
execution_plan->SetContext(inner_context_);
for (auto graph_segment : graph_segments) {
FuncGraphPtr fg = nullptr;
@ -115,6 +120,7 @@ std::shared_ptr<abstract::ExecutionPlan> DefaultGraphCompiler::Schedule(
delete output_isolate_map;
return nullptr;
}
execution_flow->SetContext(inner_context_);
for (size_t i = 0; i < execution_flow->GetInputs().size(); i++) {
auto input_tensor = execution_flow->GetInputs()[i];
@ -231,4 +237,10 @@ std::shared_ptr<abstract::ExecutionFlow> DefaultGraphCompiler::Schedule(const Gr
// implementation by hangangqiang
return nullptr;
}
static std::shared_ptr<infer::abstract::GraphCompiler> DefaultGraphCompilerCreator(const std::shared_ptr<Context> &ctx) {
auto graph_compiler = std::make_shared<DefaultGraphCompiler>(ctx);
return graph_compiler;
}
REG_GRAPH_COMPILER(kDefaultCompiler, DefaultGraphCompilerCreator);
} // namespace mindspore

View File

@ -20,11 +20,14 @@
#include <vector>
#include "infer/graph_compiler.h"
#include "infer/context.h"
namespace mindspore {
class DefaultGraphCompiler : public mindspore::infer::abstract::GraphCompiler {
public:
DefaultGraphCompiler() {}
explicit DefaultGraphCompiler(const std::shared_ptr<Context> &context) : context_(context) {
inner_context_ = nullptr;
}
virtual ~DefaultGraphCompiler() = default;
std::shared_ptr<abstract::ExecutionPlan> Compile(FuncGraphPtr graph) override;
@ -48,6 +51,8 @@ class DefaultGraphCompiler : public mindspore::infer::abstract::GraphCompiler {
private:
mindspore::HashMap<AnfNodePtr, infer::abstract::Tensor *> anf_tensor_map_;
const std::shared_ptr<Context> &context_;
std::shared_ptr<mindspore::infer::abstract::Context> inner_context_;
};
} // namespace mindspore

View File

@ -27,11 +27,12 @@ void GraphCompilerRegistry::RegCompiler(const mindspore::GraphCompilerType &type
graph_compiler_map_[type] = creator;
}
std::shared_ptr<infer::GraphCompiler> GraphCompilerRegistry::GetCompiler(const mindspore::GraphCompilerType &type) {
std::shared_ptr<infer::abstract::GraphCompiler> GraphCompilerRegistry::GetCompiler(
const mindspore::GraphCompilerType &type, const std::shared_ptr<Context> &context) {
auto it = graph_compiler_map_.find(type);
if (it == graph_compiler_map_.end()) {
return nullptr;
}
return it->second();
return it->second(context);
}
} // namespace mindspore

View File

@ -20,11 +20,12 @@
#include <memory>
#include "extendrt/graph_compiler/type.h"
#include "include/api/context.h"
#include "infer/graph_compiler.h"
namespace mindspore {
using GraphCompiler = infer::abstract::GraphCompiler;
using GraphCompilerRegFunc = std::function<std::shared_ptr<GraphCompiler>()>;
using GraphCompilerRegFunc =
std::function<std::shared_ptr<infer::abstract::GraphCompiler>(const std::shared_ptr<Context> &)>;
class GraphCompilerRegistry {
public:
@ -35,7 +36,8 @@ class GraphCompilerRegistry {
void RegCompiler(const mindspore::GraphCompilerType &graph_compiler_type, const GraphCompilerRegFunc &creator);
std::shared_ptr<GraphCompiler> GetCompiler(const mindspore::GraphCompilerType &type);
std::shared_ptr<infer::abstract::GraphCompiler> GetCompiler(const mindspore::GraphCompilerType &type,
const std::shared_ptr<Context> &context);
private:
mindspore::HashMap<mindspore::GraphCompilerType, GraphCompilerRegFunc> graph_compiler_map_;
@ -44,7 +46,7 @@ class GraphCompilerRegistry {
class GraphCompilerRegistrar {
public:
GraphCompilerRegistrar(const mindspore::GraphCompilerType &graph_compiler_type, const GraphCompilerRegFunc &creator) {
GraphCompilerRegistry::GetInstance().GetGraphCompiler(graph_compiler_type, creator);
GraphCompilerRegistry::GetInstance().RegCompiler(graph_compiler_type, creator);
}
~GraphCompilerRegistrar() = default;
};

View File

@ -16,9 +16,6 @@
#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_TYPE_H_
#define MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_TYPE_H_
#include <memory>
#include <vector>
namespace mindspore {
enum GraphCompilerType { kDefaultCompiler = 0, kSingleOpSession, kLiteInferSession, kDelegateSession, kNoneCompiler };
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "extendrt/graph_executor/factory.h"
#include <functional>
#include <memory>
namespace mindspore {
GraphExecutorRegistry &GraphExecutorRegistry::GetInstance() {
static GraphExecutorRegistry instance;
return instance;
}
void GraphExecutorRegistry::RegExecutor(const mindspore::GraphExecutorType &type, const GraphExecutorRegFunc &creator) {
graph_executor_map_[type] = creator;
}
std::shared_ptr<infer::abstract::Executor> GraphExecutorRegistry::GetExecutor(
const mindspore::GraphExecutorType &type, const std::string &name,
std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan) {
auto it = graph_executor_map_.find(type);
if (it == graph_executor_map_.end()) {
return nullptr;
}
return it->second(name, execution_plan);
}
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_EXECUTOR_FACTORY_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_EXECUTOR_FACTORY_H_
#include <functional>
#include <memory>
#include <string>
#include "extendrt/graph_executor/type.h"
#include "infer/executor.h"
#include "infer/execution_plan.h"
namespace mindspore {
using GraphExecutorRegFunc = std::function<std::shared_ptr<infer::abstract::Executor>(
const std::string &name, std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan)>;
class GraphExecutorRegistry {
public:
GraphExecutorRegistry() = default;
virtual ~GraphExecutorRegistry() = default;
static GraphExecutorRegistry &GetInstance();
void RegExecutor(const GraphExecutorType &type, const GraphExecutorRegFunc &creator);
std::shared_ptr<infer::abstract::Executor> GetExecutor(
const mindspore::GraphExecutorType &type, const std::string &name,
std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan);
private:
mindspore::HashMap<GraphExecutorType, GraphExecutorRegFunc> graph_executor_map_;
};
class GraphExecutorRegistrar {
public:
GraphExecutorRegistrar(const mindspore::GraphExecutorType &type, const GraphExecutorRegFunc &creator) {
GraphExecutorRegistry::GetInstance().RegExecutor(type, creator);
}
~GraphExecutorRegistrar() = default;
};
#define REG_GRAPH_EXECUTOR(type, creator) static GraphExecutorRegistrar g_##type##GraphExecutor(type, creator);
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_EXECUTOR_FACTORY_H_
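All the factories in this PR (compiler, executor, runtime) share the same shape: a singleton registry maps an enum to a creator function, and a file-scope registrar object, expanded from the REG_* macro, registers the creator during static initialization. A minimal self-contained sketch of the pattern follows, with simplified stand-in types rather than the actual MindSpore classes.

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

enum DemoExecutorType { kDemoDefault = 0 };

struct DemoExecutor {
  virtual ~DemoExecutor() = default;
  virtual const std::string &Name() = 0;
};

using DemoRegFunc = std::function<std::shared_ptr<DemoExecutor>(const std::string &)>;

class DemoRegistry {
 public:
  static DemoRegistry &GetInstance() {  // Meyers singleton, as in the PR's registries
    static DemoRegistry instance;
    return instance;
  }
  void Reg(DemoExecutorType type, const DemoRegFunc &creator) { map_[type] = creator; }
  std::shared_ptr<DemoExecutor> Get(DemoExecutorType type, const std::string &name) {
    auto it = map_.find(type);
    if (it == map_.end()) return nullptr;  // unknown type yields nullptr
    return it->second(name);
  }
 private:
  std::unordered_map<DemoExecutorType, DemoRegFunc> map_;
};

// What a REG_* macro expands to: a static object whose constructor registers the creator.
struct DemoRegistrar {
  DemoRegistrar(DemoExecutorType type, const DemoRegFunc &creator) {
    DemoRegistry::GetInstance().Reg(type, creator);
  }
};

struct DefaultDemoExecutor : public DemoExecutor {
  explicit DefaultDemoExecutor(std::string name) : name_(std::move(name)) {}
  const std::string &Name() override { return name_; }
  std::string name_;
};

// Runs before main(), so lookups by type succeed without any explicit setup call.
static DemoRegistrar g_demo(kDemoDefault, [](const std::string &name) {
  return std::make_shared<DefaultDemoExecutor>(name);
});

int main() {
  auto exec = DemoRegistry::GetInstance().Get(kDemoDefault, "demo");
  std::cout << (exec ? exec->Name() : "not found") << "\n";  // prints "demo"
  return 0;
}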

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "extendrt/flow_executor.h"
#include "extendrt/graph_executor/flow_executor.h"
#include "extendrt/execution_plan.h"
#include "litert/mindrt_executor.h"

View File

@ -35,7 +35,7 @@ class FlowExecutor : public mindspore::infer::abstract::Executor {
const std::string &Name() override { return name_; }
Status Prepare(std::shared_ptr<abstract::ExecutionFlow> execution_flow) override;
Status Prepare() override;
Status Execute() override;

View File

@ -0,0 +1,84 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "extendrt/graph_executor/mindrt_graph_executor.h"
#include "src/common/log.h"
#include "extendrt/graph_executor/factory.h"
#include "litert/mindrt_executor.h"
#include "extendrt/execution_plan.h"
namespace mindspore {
MindRTGraphExecutor::MindRTGraphExecutor() {
name_ = "";
execution_plan_ = nullptr;
}
MindRTGraphExecutor::MindRTGraphExecutor(const std::string &name,
std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan) {
name_ = name;
execution_plan_ = execution_plan;
auto infer_execution_plan = std::dynamic_pointer_cast<infer::ExecutionPlan>(execution_plan_);
if (infer_execution_plan == nullptr) {
MS_LOG(ERROR) << "MindRTGraphExecutor::MindRTGraphExecutor Not Supported execution plan is passed";
} else {
mindrt_executor_ = std::make_shared<mindspore::lite::MindrtExecutor>(infer_execution_plan->GetInputMap(),
infer_execution_plan->GetOutputMap());
}
}
Status MindRTGraphExecutor::Prepare() {
if (mindrt_executor_ == nullptr) {
MS_LOG(ERROR) << "FlowExecutor::Prepare executor is nullptr";
return kLiteError;
}
if (execution_plan_ == nullptr) {
MS_LOG(ERROR) << "FlowExecutor::Prepare execution plan is nullptr";
return kLiteError;
}
return mindrt_executor_->Prepare(execution_plan_->ToKernelList(), execution_plan_->GetInputs(),
execution_plan_->GetOutputs(), execution_plan_->GetContext().get());
}
Status MindRTGraphExecutor::Execute() {
if (mindrt_executor_ == nullptr) {
MS_LOG(ERROR) << "FlowExecutor::Execute executor is nullptr";
return kLiteError;
}
if (execution_plan_ == nullptr) {
MS_LOG(ERROR) << "FlowExecutor::Execute execution plan is nullptr";
return kLiteError;
}
return mindrt_executor_->Run(execution_plan_->GetInputs(), execution_plan_->GetOutputs(),
execution_plan_->ToKernelList(), execution_plan_->GetKernelBeforeCallBack(),
execution_plan_->GetKernelAfterCallBack());
}
int MindRTGraphExecutor::Resize(const std::vector<infer::abstract::Tensor *> &inputs,
const std::vector<std::vector<int>> &dims) {
if (mindrt_executor_ == nullptr) {
MS_LOG(ERROR) << "FlowExecutor::Resize executor is nullptr";
return kLiteError;
}
return mindrt_executor_->Resize(inputs, dims);
}
static std::shared_ptr<infer::abstract::Executor> MindRTGraphExecutorCreator(
const std::string &name, std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan) {
auto graph_executor = std::make_shared<MindRTGraphExecutor>(name, execution_plan);
return graph_executor;
}
REG_GRAPH_EXECUTOR(kMindRTExecutor, MindRTGraphExecutorCreator);
} // namespace mindspore

View File

@ -0,0 +1,49 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_EXECUTOR_MINDRT_GRAPH_EXECUTOR_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_EXECUTOR_MINDRT_GRAPH_EXECUTOR_H_
#include <vector>
#include <memory>
#include <string>
#include "infer/executor.h"
#include "infer/execution_plan.h"
#include "litert/executor.h"
namespace mindspore {
class MindRTGraphExecutor : public mindspore::infer::abstract::Executor {
public:
MindRTGraphExecutor();
explicit MindRTGraphExecutor(const std::string &name, std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan);
virtual ~MindRTGraphExecutor() = default;
const std::string &Name() override { return name_; }
Status Prepare() override;
Status Execute() override;
int Resize(const std::vector<infer::abstract::Tensor *> &inputs, const std::vector<std::vector<int>> &dims) override;
private:
std::string name_;
std::shared_ptr<mindspore::lite::Executor> mindrt_executor_;
std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan_;
};
} // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_EXECUTOR_MINDRT_GRAPH_EXECUTOR_H_

View File

@ -0,0 +1,70 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "extendrt/graph_executor/plan_executor.h"
#include "extendrt/execution_plan.h"
#include "litert/mindrt_executor.h"
namespace mindspore::infer {
PlanExecutor::PlanExecutor() { PlanExecutor("PlanExecutor"); }
PlanExecutor::PlanExecutor(const std::string &name, std::shared_ptr<abstract::ExecutionPlan> execution_plan) {
name_ = name;
execution_plan_ = execution_plan;
auto infer_execution_plan = std::dynamic_pointer_cast<infer::ExecutionPlan>(execution_plan_);
if (infer_execution_plan == nullptr) {
MS_LOG(ERROR) << "FlowExecutor::FlowExecutor Not Supported execution plan is passed";
} else {
executor_ = std::make_shared<mindspore::lite::MindrtExecutor>(infer_execution_plan->GetInputMap(),
infer_execution_plan->GetOutputMap());
}
}
Status PlanExecutor::Prepare() {
if (executor_ == nullptr) {
MS_LOG(ERROR) << "FlowExecutor::Prepare executor is nullptr";
return kLiteError;
}
if (execution_plan_ == nullptr) {
MS_LOG(ERROR) << "FlowExecutor::Prepare execution plan is nullptr";
return kLiteError;
}
return executor_->Prepare(execution_plan_->ToKernelList(), execution_plan_->GetInputs(),
execution_plan_->GetOutputs(), execution_plan_->GetContext().get());
}
Status PlanExecutor::Execute() {
if (executor_ == nullptr) {
MS_LOG(ERROR) << "FlowExecutor::Execute executor is nullptr";
return kLiteError;
}
if (execution_plan_ == nullptr) {
MS_LOG(ERROR) << "FlowExecutor::Execute execution plan is nullptr";
return kLiteError;
}
return executor_->Run(execution_plan_->GetInputs(), execution_plan_->GetOutputs(), execution_plan_->ToKernelList(),
execution_plan_->GetKernelBeforeCallBack(), execution_plan_->GetKernelAfterCallBack());
}
int PlanExecutor::Resize(const std::vector<abstract::Tensor *> &inputs, const std::vector<std::vector<int>> &dims) {
if (executor_ == nullptr) {
MS_LOG(ERROR) << "FlowExecutor::Resize executor is nullptr";
return kLiteError;
}
return executor_->Resize(inputs, dims);
}
} // namespace mindspore::infer
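A note on the default constructor corrected above: writing PlanExecutor("PlanExecutor"); inside a constructor body only builds and discards a temporary, leaving the current object untouched; the member-initializer-list form delegates properly (C++11). A minimal illustration with a hypothetical Widget type:

#include <iostream>
#include <string>

struct Widget {
  Widget() : Widget("default") {}  // delegating constructor: forwards to the overload below
  explicit Widget(std::string name) : name_(std::move(name)) {}
  std::string name_;
};

int main() {
  Widget w;
  std::cout << w.name_ << "\n";  // prints "default"; a body-call temporary would leave name_ empty
  return 0;
}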

View File

@ -0,0 +1,51 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_PLAN_EXECUTOR_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_PLAN_EXECUTOR_H_
#include <memory>
#include <vector>
#include <string>
#include "infer/executor.h"
#include "infer/execution_plan.h"
#include "litert/executor.h"
namespace mindspore::infer {
class PlanExecutor : public mindspore::infer::abstract::Executor {
public:
PlanExecutor();
explicit PlanExecutor(const std::string &name, std::shared_ptr<abstract::ExecutionPlan> execution_plan = nullptr);
virtual ~PlanExecutor() = default;
const std::string &Name() override { return name_; }
Status Prepare() override;
Status Execute() override;
int Resize(const std::vector<abstract::Tensor *> &inputs, const std::vector<std::vector<int>> &dims) override;
private:
std::string name_;
std::shared_ptr<mindspore::lite::Executor> executor_;
std::shared_ptr<abstract::ExecutionPlan> execution_plan_;
};
} // namespace mindspore::infer
#endif // MINDSPORE_LITE_SRC_EXTENDRT_PLAN_EXECUTOR_H_

View File

@ -0,0 +1,22 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_EXECUTOR_TYPE_H_
#define MINDSPORE_LITE_EXTENDRT_GRAPH_EXECUTOR_TYPE_H_
namespace mindspore {
enum GraphExecutorType { kDefaultExecutor = 0, kMindRTExecutor, kNoneExcutor };
} // namespace mindspore
#endif // MINDSPORE_LITE_EXTENDRT_GRAPH_EXECUTOR_TYPE_H_

View File

@ -15,8 +15,9 @@
*/
#include "extendrt/graph_runtime/default_graph_runtime.h"
#include "extendrt/flow_executor.h"
#include "extendrt/graph_executor/plan_executor.h"
#include "src/common/log.h"
#include "extendrt/graph_runtime/factory.h"
namespace mindspore {
using ExecutionPlan = mindspore::infer::abstract::ExecutionPlan;
@ -30,20 +31,20 @@ Status DefaultGraphRuntime::Prepare(std::shared_ptr<ExecutionPlan> execution_pla
}
execution_plan_ = execution_plan;
for (auto execution_flow : execution_plan->GetExecutionFLows()) {
auto executor = SelectExecutor(execution_flow);
if (executor == nullptr) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Prepare Select Executor is nullptr.";
return kLiteNullptr;
}
MS_LOG(DEBUG) << "DefaultGraphRuntime::Prepare Prepare Execution Plan Begin of Executor " << executor->Name();
auto status = executor->Prepare(execution_flow);
if (status != kSuccess) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Prepare Prepare Execution Plan Failed in Executor " << executor->Name();
return kLiteError;
}
MS_LOG(DEBUG) << "DefaultGraphRuntime::Prepare Prepare Execution Plan End";
auto executor = SelectExecutor();
if (executor == nullptr) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Prepare Select Executor is nullptr.";
return kLiteNullptr;
}
MS_LOG(DEBUG) << "DefaultGraphRuntime::Prepare Prepare Execution Plan Begin of Executor " << executor->Name();
auto status = executor->Prepare(nullptr);
if (status != kSuccess) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Prepare Prepare Execution Plan Failed in Executor " << executor->Name();
return kLiteError;
}
MS_LOG(DEBUG) << "DefaultGraphRuntime::Prepare Prepare Execution Plan End";
MS_LOG(INFO) << "AbstractRuntime::Prepare End";
return kSuccess;
}
@ -56,27 +57,27 @@ Status DefaultGraphRuntime::Execute() {
return kLiteNullptr;
}
for (auto execution_flow : execution_plan_->GetExecutionFLows()) {
auto executor = SelectExecutor(execution_flow);
if (executor == nullptr) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Select Executor is nullptr.";
return kLiteNullptr;
}
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Execution Plan Begin of Executor " << executor->Name();
auto status = executor->Execute();
if (status != kSuccess) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Execution Plan Failed in Executor " << executor->Name();
return kLiteError;
}
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Prepare Execution Plan End";
auto executor = SelectExecutor();
if (executor == nullptr) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Select Executor is nullptr.";
return kLiteNullptr;
}
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Execute Execution Plan Begin of Executor " << executor->Name();
auto status = executor->Execute();
if (status != kSuccess) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Execute Execution Plan Failed in Executor " << executor->Name();
return kLiteError;
}
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Execute Execution Plan End";
MS_LOG(INFO) << "DefaultGraphRuntime::Execute End";
return kSuccess;
}
Status DefaultGraphRuntime::Execute(const std::vector<abstract::Tensor *> &inputs,
const std::vector<abstract::Tensor *> &outputs, abstract::KernelCallBack before,
abstract::KernelCallBack after) {
Status DefaultGraphRuntime::Execute(const std::vector<infer::abstract::Tensor *> &inputs,
const std::vector<infer::abstract::Tensor *> &outputs,
infer::abstract::KernelCallBack before, infer::abstract::KernelCallBack after) {
MS_LOG(INFO) << "DefaultGraphRuntime::Execute Begin";
if (execution_plan_ == nullptr) {
@ -84,37 +85,65 @@ Status DefaultGraphRuntime::Execute(const std::vector<abstract::Tensor *> &input
return kLiteNullptr;
}
for (auto &execution_flow : execution_plan_->GetExecutionFLows()) {
auto executor = SelectExecutor(execution_flow);
if (executor == nullptr) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Select Executor is nullptr.";
return kLiteNullptr;
}
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Execution Plan Begin of Executor " << executor->Name();
execution_flow->SetInputs(inputs);
execution_flow->SetOutputs(outputs);
execution_flow->SetKernelBeforeCallBack(before);
execution_flow->SetKernelAfterCallBack(after);
auto status = executor->Execute();
if (status != kSuccess) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Execution Plan Failed in Executor " << executor->Name();
return kLiteError;
}
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Prepare Execution Plan End";
auto executor = SelectExecutor();
if (executor == nullptr) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Select Executor is nullptr.";
return kLiteNullptr;
}
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Execute Execution Plan Begin of Executor " << executor->Name();
execution_plan_->SetInputs(inputs);
execution_plan_->SetOutputs(outputs);
execution_plan_->SetKernelBeforeCallBack(before);
execution_plan_->SetKernelAfterCallBack(after);
auto status = executor->Execute();
if (status != kSuccess) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Execute Execution Plan Failed in Executor " << executor->Name();
return kLiteError;
}
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Execute Execution Plan End";
MS_LOG(INFO) << "DefaultGraphRuntime::Execute End";
return kSuccess;
}
std::shared_ptr<abstract::Executor> DefaultGraphRuntime::SelectExecutor(
const std::shared_ptr<abstract::ExecutionFlow> &execution_flow) {
auto it = executor_map_.find(execution_flow);
if (it == executor_map_.end()) {
// create a new executor for execution flow
auto executor = std::make_shared<infer::FlowExecutor>("flow-executor");
executor_map_[execution_flow] = executor;
return executor;
Status DefaultGraphRuntime::Resize(const std::vector<infer::abstract::Tensor *> *inputs,
const std::vector<std::vector<int64_t>> &dims) {
MS_LOG(INFO) << "DefaultGraphRuntime::Resize Begin";
if (execution_plan_ == nullptr) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Resize Execution Plan is nullptr.";
return kLiteNullptr;
}
return it->second;
auto executor = SelectExecutor();
if (executor == nullptr) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Resize Select Executor is nullptr.";
return kLiteNullptr;
}
MS_LOG(DEBUG) << "DefaultGraphRuntime::Resize Resize Execution Plan Begin of Executor " << executor->Name();
auto status = executor->Resize(inputs, dims);
if (status != kSuccess) {
MS_LOG(ERROR) << "DefaultGraphRuntime::Resize Resize Execution Plan Failed in Executor " << executor->Name();
return kLiteError;
}
MS_LOG(DEBUG) << "DefaultGraphRuntime::Resize Resize Execution Plan End";
MS_LOG(INFO) << "DefaultGraphRuntime::Resize End";
return kSuccess;
}
std::shared_ptr<infer::abstract::Executor> DefaultGraphRuntime::SelectExecutor() {
if (default_executor_ == nullptr) {
default_executor_ = std::make_shared<infer::PlanExecutor>("plan-executor");
}
return default_executor_;
}
static std::shared_ptr<infer::abstract::GraphRuntime> DefaultGraphRuntimeCreator() {
auto graph_runtime = std::make_shared<DefaultGraphRuntime>();
return graph_runtime;
}
REG_GRAPH_RUNTIME(kDefaultRuntime, DefaultGraphRuntimeCreator);
} // namespace mindspore

View File

@ -27,19 +27,24 @@ class DefaultGraphRuntime : public mindspore::infer::abstract::GraphRuntime {
DefaultGraphRuntime() = default;
virtual ~DefaultGraphRuntime() = default;
Status Prepare(std::shared_ptr<abstract::ExecutionPlan> execution_plan) override;
Status Prepare(std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan) override;
Status Execute() override;
Status Execute(const std::vector<abstract::Tensor *> &inputs, const std::vector<abstract::Tensor *> &outputs,
abstract::KernelCallBack before = nullptr, abstract::KernelCallBack after = nullptr) override;
Status Execute(const std::vector<infer::abstract::Tensor *> &inputs,
const std::vector<infer::abstract::Tensor *> &outputs,
infer::abstract::KernelCallBack before = nullptr,
infer::abstract::KernelCallBack after = nullptr) override;
Status Resize(const std::vector<infer::abstract::Tensor *> *inputs,
const std::vector<std::vector<int64_t>> &dims) override;
private:
std::shared_ptr<abstract::Executor> SelectExecutor(const std::shared_ptr<abstract::ExecutionFlow> &execution_flow);
std::shared_ptr<infer::abstract::Executor> SelectExecutor();
private:
std::shared_ptr<abstract::ExecutionPlan> execution_plan_ = nullptr;
mindspore::HashMap<std::shared_ptr<abstract::ExecutionFlow>, std::shared_ptr<abstract::Executor>> executor_map_;
std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan_ = nullptr;
std::shared_ptr<infer::abstract::Executor> default_executor_ = nullptr;
};
} // namespace mindspore

View File

@ -18,16 +18,17 @@
#include <memory>
namespace mindspore {
GraphRuntimRegistry &GraphRuntimRegistry::GetInstance() {
static GraphRuntimRegistry instance;
GraphRuntimeRegistry &GraphRuntimeRegistry::GetInstance() {
static GraphRuntimeRegistry instance;
return instance;
}
void GraphRuntimRegistry::RegRuntime(const mindspore::GraphRuntimeType &type, const GraphRuntimeRegFunc &creator) {
void GraphRuntimeRegistry::RegRuntime(const mindspore::GraphRuntimeType &type, const GraphRuntimeRegFunc &creator) {
graph_runtime_map_[type] = creator;
}
std::shared_ptr<infer::GraphRuntime> GraphRuntimRegistry::GetRuntime(const mindspore::GraphRuntimeType &type) {
std::shared_ptr<infer::abstract::GraphRuntime> GraphRuntimeRegistry::GetRuntime(
const mindspore::GraphRuntimeType &type) {
auto it = graph_runtime_map_.find(type);
if (it == graph_runtime_map_.end()) {
return nullptr;

View File

@ -23,19 +23,18 @@
#include "infer/graph_runtime.h"
namespace mindspore {
using GraphRuntime = infer::abstract::GraphRuntime;
using GraphRuntimeRegFunc = std::function<std::shared_ptr<GraphRuntime>()>;
using GraphRuntimeRegFunc = std::function<std::shared_ptr<infer::abstract::GraphRuntime>()>;
class GraphRuntimRegistry {
class GraphRuntimeRegistry {
public:
GraphRuntimRegistry() = default;
virtual ~GraphRuntimRegistry() = default;
GraphRuntimeRegistry() = default;
virtual ~GraphRuntimeRegistry() = default;
static GraphRuntimRegistry &GetInstance();
static GraphRuntimeRegistry &GetInstance();
void RegRuntime(const GraphRuntimeType &type, const GraphRuntimeRegFunc &creator);
std::shared_ptr<GraphRuntime> GetRuntime(const mindspore::GraphRuntimeType &type);
std::shared_ptr<infer::abstract::GraphRuntime> GetRuntime(const mindspore::GraphRuntimeType &type);
private:
mindspore::HashMap<GraphRuntimeType, GraphRuntimeRegFunc> graph_runtime_map_;
@ -44,7 +43,7 @@ class GraphRuntimRegistry {
class GraphRuntimeRegistrar {
public:
GraphRuntimeRegistrar(const mindspore::GraphRuntimeType &type, const GraphRuntimeRegFunc &creator) {
GraphRuntimRegistry::GetInstance().RegRuntime(type, creator);
GraphRuntimeRegistry::GetInstance().RegRuntime(type, creator);
}
~GraphRuntimeRegistrar() = default;
};

View File

@ -16,9 +16,6 @@
#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_RUNTIME_TYPE_H_
#define MINDSPORE_LITE_EXTENDRT_GRAPH_RUNTIME_TYPE_H_
#include <memory>
#include <vector>
namespace mindspore {
enum GraphRuntimeType { kDefaultRuntime = 0, kSingleOpSession, kLiteInferSession, kDelegateSession, kNoneRuntime };
} // namespace mindspore

View File

@ -22,6 +22,7 @@
#include "extendrt/session/factory.h"
#include "extendrt/graph_compiler/factory.h"
#include "extendrt/graph_runtime/factory.h"
#include "extendrt/utils/tensor_utils.h"
#include "backend/graph_compiler/graph_partition.h"
#include "litert/cxx_api/tensor/tensor_impl.h"
@ -32,18 +33,18 @@ static const std::vector<PrimitivePtr> ms_infer_cut_list = {prim::kPrimReturn,
prim::kPrimBpropCut, prim::kPrimSwitchLayer};
Status DefaultInferSession::Init(const std::shared_ptr<Context> &context) {
MS_LOG(INFO) << "DefaultInferSession::Init";
context_ = context;
// context_ = context;
// Set MSContext::GetInstance param?
// init compiler and runtime according to context
compiler_ = GraphCompilerRegistry::GetInstance()->GetCompiler(kDefaultCompiler);
compiler_ = GraphCompilerRegistry::GetInstance().GetCompiler(kDefaultCompiler, context_);
if (compiler_ == nullptr) {
MS_LOG(ERROR) << "DefaultInferSession::Init Get Compiler is nullptr";
return kLiteNullptr;
}
runtime_ = GraphRuntimRegistry::GetInstance()->GetRuntime(kDefaultRuntime);
runtime_ = GraphRuntimeRegistry::GetInstance().GetRuntime(kDefaultRuntime);
if (runtime_ == nullptr) {
MS_LOG(ERROR) << "DefaultInferSession::Init Get Runtime is nullptr";
return kLiteNullptr;
@ -68,7 +69,7 @@ Status DefaultInferSession::CompileGraph(FuncGraphPtr graph, const void *data, s
MS_LOG(DEBUG) << "DefaultInferSession::CompileGraph Compile Graph End";
MS_LOG(DEBUG) << "DefaultInferSession::CompileGraph Prepare ExecutionPlan Begin";
auto runtime = this->GetRuntime();
auto runtime = this->GetGraphRuntime();
if (runtime == nullptr) {
MS_LOG(ERROR) << "DefaultInferSession::CompileGraph Runtime in Infer Session is null";
return kLiteNullptr;
@ -86,7 +87,7 @@ Status DefaultInferSession::CompileGraph(FuncGraphPtr graph, const void *data, s
Status DefaultInferSession::RunGraph(const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
MS_LOG(DEBUG) << "DefaultInferSession::RunGraph Execute ExecutionPlan Begin";
auto runtime = this->GetRuntime();
auto runtime = this->GetGraphRuntime();
if (runtime == nullptr) {
MS_LOG(ERROR) << "DefaultInferSession::RunGraph Runtime in Infer Session is null";
return kLiteNullptr;
@ -101,7 +102,7 @@ Status DefaultInferSession::RunGraph(const std::vector<tensor::Tensor> &inputs,
MS_LOG(ERROR) << "DefaultInferSession::RunGraph Copy Data Pointer to input tensors failed";
return status;
}
status = CopyDataToInnerTensors(outputs, inner_outputs);
status = CopyDataToInnerTensors(*outputs, inner_outputs);
if (status != kSuccess) {
MS_LOG(ERROR) << "DefaultInferSession::RunGraph Copy Data Pointer to output tensors failed";
return status;
@ -124,11 +125,16 @@ Status DefaultInferSession::RunGraph(const std::vector<tensor::Tensor> &inputs,
return RunGraph(inputs, outputs, nullptr, nullptr);
}
Status DefaultInferSession::Resize(const std::vector<tensor::Tensor> &inputs,
const std::vector<std::vector<int64_t>> &dims) {
return kSuccess;
}
std::vector<MutableTensorImplPtr> DefaultInferSession::GetOutputs() {
auto runtime = this->GetRuntime();
auto runtime = this->GetGraphRuntime();
if (runtime == nullptr) {
MS_LOG(ERROR) << "DefaultInferSession::GetOutputs Runtime in Infer Session is null";
return kLiteNullptr;
return std::vector<MutableTensorImplPtr>{};
}
auto lite_outputs = runtime->GetOutputs();
MS_LOG(DEBUG) << "DefaultInferSession::GetOutputs end";
@ -136,10 +142,10 @@ std::vector<MutableTensorImplPtr> DefaultInferSession::GetOutputs() {
}
std::vector<MutableTensorImplPtr> DefaultInferSession::GetInputs() {
auto runtime = this->GetRuntime();
auto runtime = this->GetGraphRuntime();
if (runtime == nullptr) {
MS_LOG(ERROR) << "DefaultInferSession::GetInputs Runtime in Infer Session is null";
return kLiteNullptr;
return std::vector<MutableTensorImplPtr>{};
}
auto lite_inputs = runtime->GetInputs();
MS_LOG(DEBUG) << "DefaultInferSession::GetOutputs end";
@ -152,7 +158,7 @@ MutableTensorImplPtr DefaultInferSession::GetOutputByTensorName(const std::strin
MutableTensorImplPtr DefaultInferSession::GetInputByTensorName(const std::string &name) { return nullptr; }
Status DefaultInferSession::CopyDataToInnerTensors(const std::vector<tensor::Tensor> &tensors,
std::vector<abstract::Tensor *> inner_tensors) {
std::vector<infer::abstract::Tensor *> inner_tensors) {
if (tensors.size() != inner_tensors.size()) {
MS_LOG(EXCEPTION) << "user input size " << tensors.size() << " is not equal to graph input size "
<< inner_tensors.size();
@ -200,17 +206,17 @@ Status DefaultInferSession::CopyDataToInnerTensors(const std::vector<tensor::Ten
return kSuccess;
}
std::vector<MutableTensorImplPtr> &DefaultInferSession::AbstractTensorsToTensorImpls(
const std::vector<abstract::Tensor *> &abstract_tensors) {
std::vector<std::shared_ptr<LiteTensorImpl>> tensorImpls;
std::vector<MutableTensorImplPtr> DefaultInferSession::AbstractTensorsToTensorImpls(
const std::vector<infer::abstract::Tensor *> &abstract_tensors) {
std::vector<MutableTensorImplPtr> tensorImpls;
tensorImpls.reserve(abstract_tensors.size());
(void)std::transform(abstract_tensors.begin(), abstract_tensors.end(), std::back_inserter(tensorImpls),
[](abstract::Tensor *tensor) { return std::make_shared<LiteTensorImpl>(tensor); });
[](infer::abstract::Tensor *tensor) { return std::make_shared<LiteTensorImpl>(tensor); });
return tensorImpls;
}
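The signature change above, returning std::vector<MutableTensorImplPtr> by value instead of by reference, fixes a dangling reference: the old version handed back a reference to the function-local tensorImpls, which is destroyed on return. A short standalone illustration with hypothetical names:

#include <vector>

// Returning by value is safe and cheap: the local is moved (or elided) on return.
std::vector<int> MakeGood() {
  std::vector<int> v{1, 2, 3};
  return v;
}

// The old pattern, returning a reference to a local, is undefined behavior:
//   std::vector<int> &MakeBad() {
//     std::vector<int> v{1, 2, 3};
//     return v;  // v is destroyed here; the caller receives a dangling reference
//   }

int main() {
  auto v = MakeGood();
  return v.size() == 3 ? 0 : 1;
}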
std::vector<mindspore::tensor::Tensor> DefaultInferSession::LiteTensorToTensor(
const std::vector<abstract::Tensor *> &abstract_tensors) {
const std::vector<infer::abstract::Tensor *> &abstract_tensors) {
std::vector<mindspore::tensor::Tensor> tensors;
for (auto abstract_tensor : abstract_tensors) {
if (abstract_tensor == nullptr) {
@ -222,11 +228,14 @@ std::vector<mindspore::tensor::Tensor> DefaultInferSession::LiteTensorToTensor(
auto data = abstract_tensor->MutableData();
auto data_size = abstract_tensor->Size();
auto ref_tensor_data =
std::make_shared<TensorRefData>(data, abstract_tensor->ElementNum(), data_size, shape.size());
mindspore::tensor::Tensor tensor(type_id, shape, ref_tensor_data);
std::make_shared<TensorRefData>(data, abstract_tensor->ElementsNum(), data_size, shape.size());
std::vector<int64_t> shape64;
std::transform(shape.begin(), shape.end(), std::back_inserter(shape64),
[](int dim) { return static_cast<int64_t>(dim); });
mindspore::tensor::Tensor tensor(type_id, shape64, ref_tensor_data);
auto device_address = abstract_tensor->device_data();
if (device_address != nullptr) {
auto lite_device_address = std::make_shared<LiteDeviceAddress>(device_address, abstract_tensor->DataSize());
auto lite_device_address = std::make_shared<LiteDeviceAddress>(device_address, abstract_tensor->Size());
tensor.set_device_address(lite_device_address);
}
tensors.emplace_back(std::move(tensor));
@ -234,6 +243,34 @@ std::vector<mindspore::tensor::Tensor> DefaultInferSession::LiteTensorToTensor(
return tensors;
}
std::vector<int32_t> DefaultInferSession::TruncateShape(const std::vector<int64_t> &shape, enum TypeId type,
size_t data_len, bool verify_size) {
std::vector<int32_t> empty;
if (shape.empty()) {
return empty;
}
std::vector<int32_t> truncated_shape;
truncated_shape.resize(shape.size());
size_t element_size = lite::DataTypeSize(type);
for (size_t i = 0; i < shape.size(); i++) {
auto dim = shape[i];
if (dim < 0 || dim > INT_MAX || (dim != 0 && element_size > INT_MAX / static_cast<size_t>(dim))) {
MS_LOG(ERROR) << "Invalid shape!dim: " << dim << ", element_size: " << element_size;
return empty;
} else {
element_size *= static_cast<size_t>(dim);
truncated_shape[i] = static_cast<int32_t>(dim);
}
}
if (verify_size) {
if (element_size != data_len) {
MS_LOG(ERROR) << "Invalid data size!element_size: " << element_size << ", data_len: " << data_len;
return empty;
}
}
return truncated_shape;
}
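TruncateShape folds two jobs into one loop: narrowing each dimension to int32 and accumulating the total byte count, where the guard element_size > INT_MAX / dim rejects a multiplication before it can overflow. Below is a standalone sketch of the same check without the lite::DataTypeSize dependency; the names are illustrative.

#include <climits>
#include <cstdint>
#include <cstdio>
#include <vector>

// Returns false if the shape is invalid or the running byte count would exceed INT_MAX.
bool SafeElementBytes(const std::vector<int64_t> &shape, size_t type_size, size_t *out_bytes) {
  size_t bytes = type_size;
  for (int64_t dim : shape) {
    if (dim < 0 || dim > INT_MAX) return false;
    // Guard bytes * dim <= INT_MAX without performing the overflowing multiply.
    if (dim != 0 && bytes > static_cast<size_t>(INT_MAX) / static_cast<size_t>(dim)) return false;
    bytes *= static_cast<size_t>(dim);
  }
  *out_bytes = bytes;
  return true;
}

int main() {
  size_t bytes = 0;
  std::vector<int64_t> ok = {2, 3, 4};  // 24 elements
  std::printf("%d\n", SafeElementBytes(ok, sizeof(float), &bytes));  // 1 (bytes == 96)
  std::vector<int64_t> huge = {1 << 20, 1 << 20};  // ~1e12 elements
  std::printf("%d\n", SafeElementBytes(huge, sizeof(float), &bytes));  // 0 (would overflow)
  return 0;
}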
static std::shared_ptr<InferSession> DefaultSessionCreator(const std::shared_ptr<Context> &ctx,
const ConfigInfos &config_infos) {
auto session = std::make_shared<DefaultInferSession>(ctx);

View File

@ -37,6 +37,7 @@ class DefaultInferSession : public InferSession {
Status RunGraph(const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs) override;
Status RunGraph(const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) override;
Status Resize(const std::vector<tensor::Tensor> &inputs, const std::vector<std::vector<int64_t>> &dims) override;
std::vector<MutableTensorImplPtr> GetOutputs() override;
std::vector<MutableTensorImplPtr> GetInputs() override;
std::vector<std::string> GetOutputNames() override;
@ -45,16 +46,19 @@ class DefaultInferSession : public InferSession {
MutableTensorImplPtr GetInputByTensorName(const std::string &name) override;
protected:
virtual std::shared_ptr<infer::abstract::GraphCompiler> GetCompiler() { return compiler_; }
virtual std::shared_ptr<infer::abstract::GraphCompiler> GetGraphCompiler() { return compiler_; }
virtual std::shared_ptr<infer::abstract::GraphRuntime> GetRuntime() { return runtime_; }
virtual std::shared_ptr<infer::abstract::GraphRuntime> GetGraphRuntime() { return runtime_; }
private:
Status CopyDataToInnerTensors(const std::vector<tensor::Tensor> &tensors,
std::vector<abstract::Tensor *> inner_tensors);
std::vector<MutableTensorImplPtr> &AbstractTensorsToTensorImpls(
const std::vector<abstract::Tensor *> &abstract_tensors);
std::vector<mindspore::tensor::Tensor> LiteTensorToTensor(const std::vector<abstract::Tensor *> &abstract_tensors);
std::vector<infer::abstract::Tensor *> inner_tensors);
std::vector<MutableTensorImplPtr> AbstractTensorsToTensorImpls(
const std::vector<infer::abstract::Tensor *> &abstract_tensors);
std::vector<mindspore::tensor::Tensor> LiteTensorToTensor(
const std::vector<infer::abstract::Tensor *> &abstract_tensors);
std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
bool verify_size);
private:
std::shared_ptr<infer::abstract::GraphCompiler> compiler_;

View File

@ -67,14 +67,14 @@ class ExecutionFlow : public std::enable_shared_from_this<ExecutionFlow> {
/// \brief Get context for the execution flow.
///
/// \return Context pointer.
virtual Context *GetContext() = 0;
virtual std::shared_ptr<Context> GetContext() = 0;
/// \brief Set context of execution run
///
/// \param[in] context, context for running
///
/// \return void.
virtual void SetContext(Context *context) = 0;
virtual void SetContext(std::shared_ptr<Context> context) = 0;
/// \brief Get callback before kernel execution.
///
@ -99,6 +99,11 @@ class ExecutionFlow : public std::enable_shared_from_this<ExecutionFlow> {
///
/// \return void.
virtual void SetKernelAfterCallBack(const KernelCallBack &callback) = 0;
/// \brief Construct the flow into one fused Kernel, e.g. a SubGraphKernel.
///
/// \return Kernel pointer.
virtual Kernel *ConstructFusionKernel() = 0;
};
} // namespace mindspore::infer::abstract

View File

@ -21,6 +21,7 @@
#include "ir/func_graph.h"
#include "infer/execution_flow.h"
#include "infer/context.h"
namespace mindspore::infer::abstract {
class ExecutionPlan : public std::enable_shared_from_this<ExecutionPlan> {
@ -81,6 +82,47 @@ class ExecutionPlan : public std::enable_shared_from_this<ExecutionPlan> {
///
/// \return void.
virtual void SetOutputs(const std::vector<Tensor *> &outputs) = 0;
/// \brief Get context of execution plan.
///
/// \return Context of execution plan.
virtual std::shared_ptr<Context> GetContext() = 0;
/// \brief Set context to run.
///
/// \param[in] context, context
///
/// \return void.
virtual void SetContext(std::shared_ptr<Context> context) = 0;
/// \brief Get callback before kernel execution.
///
/// \return KernelCallBack pointer.
virtual const KernelCallBack &GetKernelBeforeCallBack() = 0;
/// \brief Set callback before kernel execution.
///
/// \param[in] callback, callback function pointer
///
/// \return void.
virtual void SetKernelBeforeCallBack(const KernelCallBack &callback) = 0;
/// \brief Get callback after kernel execution.
///
/// \return KernelCallBack pointer.
virtual const KernelCallBack &GetKernelAfterCallBack() = 0;
/// \brief Set callback after kernel execution.
///
/// \param[in] callback, callback function pointer
///
/// \return void.
virtual void SetKernelAfterCallBack(const KernelCallBack &callback) = 0;
/// \brief Convert Execution Plan to Kernel List
///
/// \return Kernel List
virtual std::vector<Kernel *> ToKernelList() = 0;
};
} // namespace mindspore::infer::abstract

View File

@ -21,7 +21,7 @@
#include <vector>
#include "include/api/status.h"
#include "infer/execution_flow.h"
#include "infer/tensor.h"
namespace mindspore::infer::abstract {
class Executor : public std::enable_shared_from_this<Executor> {
@ -38,7 +38,7 @@ class Executor : public std::enable_shared_from_this<Executor> {
/// \param[in] execution_flow Abstract Execution Plan for execute.
///
/// \return Status.
virtual Status Prepare(std::shared_ptr<ExecutionFlow> execution_flow) = 0;
virtual Status Prepare() = 0;
/// \brief Execute According to ExecutionFlow.
///

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INFER_GRAPH__RUNTIME_H_
#define MINDSPORE_LITE_INFER_GRAPH__RUNTIME_H_
#ifndef MINDSPORE_LITE_INFER_GRAPH_RUNTIME_H_
#define MINDSPORE_LITE_INFER_GRAPH_RUNTIME_H_
#include <vector>
#include <memory>
@ -50,6 +50,14 @@ class GraphRuntime : public std::enable_shared_from_this<GraphRuntime> {
virtual Status Execute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
KernelCallBack before = nullptr, KernelCallBack after = nullptr) = 0;
/// \brief Resize According to New Inputs and dims.
///
/// \param[in] inputs, input tensors to resize
/// \param[in] dims, target dim shape to resize
///
/// \return Status.
virtual Status Resize(const std::vector<Tensor *> *inputs, const std::vector<std::vector<int64_t>> &dims) = 0;
/// \brief Get list of inputs for the model.
///
/// \return vector of Tensor.
@ -62,4 +70,4 @@ class GraphRuntime : public std::enable_shared_from_this<GraphRuntime> {
};
} // namespace mindspore::infer::abstract
#endif // MINDSPORE_LITE_INFER_GRAPH__RUNTIME_H_
#endif // MINDSPORE_LITE_INFER_GRAPH_RUNTIME_H_

View File

@ -13,19 +13,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INFER_CALLBACK_H_
#define MINDSPORE_LITE_INFER_CALLBACK_H_
#ifndef MINDSPORE_LITE_INFER_KERNEL_CALLBACK_H_
#define MINDSPORE_LITE_INFER_KERNEL_CALLBACK_H_
#include <memory>
#include "executor/kernel_exec.h"
namespace mindspore::infer::abstract {
using KernelCallBack = mindspore::lite::KernelCallBack;
// class CallBack : public std::enable_shared_from_this<CallBack> {
// public:
// virtual ~CallBack() = default;
// };
} // namespace mindspore::infer::abstract
#endif // MINDSPORE_LITE_INFER_CALLBACK_H_
#endif // MINDSPORE_LITE_INFER_KERNEL_CALLBACK_H_