forked from mindspore-Ecosystem/mindspore
!49445 [MS]{Lite][Task] optimize new infer framework
Merge pull request !49445 from 刘力力/feature_new_infer_merge
This commit is contained in:
commit
b127cd5015
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "extendrt/execution_flow.h"
|
||||
|
||||
#include "litert/lite_kernel.h"
|
||||
#include "litert/kernel_exec_util.h"
|
||||
#include "litert/sub_graph_kernel.h"
|
||||
|
||||
namespace mindspore::infer {
|
||||
ExecutionFlow::~ExecutionFlow() {
|
||||
for (auto tensor : inputs_) {
|
||||
if (tensor != nullptr) {
|
||||
delete tensor;
|
||||
}
|
||||
}
|
||||
for (auto tensor : outputs_) {
|
||||
if (tensor != nullptr) {
|
||||
delete tensor;
|
||||
}
|
||||
}
|
||||
for (auto kernel : kernels_) {
|
||||
if (kernel != nullptr) {
|
||||
delete kernel;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
abstract::Kernel *ExecutionFlow::ConstructFusionKernel() {
|
||||
auto lite_kernel = new (std::nothrow) mindspore::kernel::LiteKernel(nullptr, inputs_, outputs_, context_);
|
||||
if (lite_kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "ExecutionFlow::ConstructFusionKernel create lite kernel failed, may be memory is not enough";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<KernelExec *> input_kernels = mindspore::kernel::KernelExecUtil::SubgraphInputNodes(kernels_);
|
||||
std::vector<KernelExec *> output_kernels = mindspore::kernel::KernelExecUtil::SubgraphOutputNodes(kernels_);
|
||||
|
||||
mindspore::kernel::SubGraphKernel *sub_graph_kernel =
|
||||
new CpuFp32SubGraph(input_kernels, output_kernels, kernels_, lite_kernel);
|
||||
if (sub_graph_kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "ExecutionFlow::ConstructFusionKernel create sub graph kernel failed, may be memory is not enough";
|
||||
delete lite_kernel return nullptr;
|
||||
}
|
||||
sub_graph_kernel->set_context(context_);
|
||||
return sub_graph_kernel;
|
||||
}
|
||||
} // namespace mindspore::infer
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_LITE_SRC_EXTENDRT_EXECUTION_FLOW_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "infer/execution_flow.h"
|
||||
|
||||
|
@ -25,23 +26,7 @@ namespace mindspore::infer {
|
|||
class ExecutionFlow : public abstract::ExecutionFlow {
|
||||
public:
|
||||
ExecutionFlow() = default;
|
||||
virtual ~ExecutionFlow() {
|
||||
for (auto tensor : inputs_) {
|
||||
if (tensor != nullptr) {
|
||||
delete tensor;
|
||||
}
|
||||
}
|
||||
for (auto tensor : outputs_) {
|
||||
if (tensor != nullptr) {
|
||||
delete tensor;
|
||||
}
|
||||
}
|
||||
for (auto kernel : kernels_) {
|
||||
if (kernel != nullptr) {
|
||||
delete kernel;
|
||||
}
|
||||
}
|
||||
}
|
||||
virtual ~ExecutionFlow();
|
||||
|
||||
std::vector<abstract::Kernel *> GetKernels() override { return kernels_; }
|
||||
|
||||
|
@ -55,9 +40,9 @@ class ExecutionFlow : public abstract::ExecutionFlow {
|
|||
|
||||
void SetOutputs(const std::vector<abstract::Tensor *> &outputs) override { outputs_ = outputs; }
|
||||
|
||||
abstract::Context *GetContext() override { return context_; }
|
||||
std::shared_ptr<abstract::Context> GetContext() override { return context_; }
|
||||
|
||||
void SetContext(abstract::Context *context) override { context_ = context; }
|
||||
void SetContext(std::shared_ptr<abstract::Context> context) override { context_ = context; }
|
||||
|
||||
const abstract::KernelCallBack &GetKernelBeforeCallBack() override { return before_; }
|
||||
|
||||
|
@ -67,11 +52,13 @@ class ExecutionFlow : public abstract::ExecutionFlow {
|
|||
|
||||
void SetKernelAfterCallBack(const abstract::KernelCallBack &callback) override { after_ = callback; }
|
||||
|
||||
abstract::Kernel *ConstructFusionKernel() override;
|
||||
|
||||
private:
|
||||
std::vector<abstract::Kernel *> kernels_;
|
||||
std::vector<abstract::Tensor *> inputs_;
|
||||
std::vector<abstract::Tensor *> outputs_;
|
||||
abstract::Context *context_;
|
||||
std::shared_ptr<abstract::Context> context_;
|
||||
abstract::KernelCallBack before_;
|
||||
abstract::KernelCallBack after_;
|
||||
};
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "extendrt/execution_plan.h"
|
||||
|
||||
#include "litert/lite_kernel.h"
|
||||
#include "litert/kernel_exec_util.h"
|
||||
#include "litert/sub_graph_kernel.h"
|
||||
|
||||
namespace mindspore::infer {
|
||||
ExecutionPlan::~ExecutionPlan() {
|
||||
if (input_isolate_map_ != nullptr) {
|
||||
delete input_isolate_map_;
|
||||
input_isolate_map_ = nullptr;
|
||||
}
|
||||
if (output_isolate_map_) {
|
||||
delete output_isolate_map_;
|
||||
output_isolate_map_ = nullptr;
|
||||
}
|
||||
|
||||
for (auto tensor : inputs_) {
|
||||
if (tensor != nullptr) {
|
||||
delete tensor;
|
||||
}
|
||||
}
|
||||
for (auto tensor : outputs_) {
|
||||
if (tensor != nullptr) {
|
||||
delete tensor;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<abstract::Kernel *> ExecutionPlan::ToKernelList() {
|
||||
std::vector<abstract::Kernel *> kernels;
|
||||
for (auto flow : execution_flows_) {
|
||||
if (flow == nullptr) {
|
||||
MS_LOG(ERROR) << "ExecutionPlan::ToKernelList get nullptr execution flow.";
|
||||
return std::vector<abstract::Kernel *>{};
|
||||
}
|
||||
auto kernel = flow->ConstructFusionKernel();
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "ExecutionPlan::ToKernelList construct execution flow to Sub Graph Kernel failed.";
|
||||
return std::vector<abstract::Kernel *>{};
|
||||
}
|
||||
kernels.emplace_back(kernel);
|
||||
}
|
||||
return kernels;
|
||||
}
|
||||
} // namespace mindspore::infer
|
|
@ -27,27 +27,7 @@ namespace mindspore::infer {
|
|||
class ExecutionPlan : public abstract::ExecutionPlan {
|
||||
public:
|
||||
ExecutionPlan() = default;
|
||||
virtual ~ExecutionPlan() {
|
||||
if (input_isolate_map_ != nullptr) {
|
||||
delete input_isolate_map_;
|
||||
input_isolate_map_ = nullptr;
|
||||
}
|
||||
if (output_isolate_map_) {
|
||||
delete output_isolate_map_;
|
||||
output_isolate_map_ = nullptr;
|
||||
}
|
||||
|
||||
for (auto tensor : inputs_) {
|
||||
if (tensor != nullptr) {
|
||||
delete tensor;
|
||||
}
|
||||
}
|
||||
for (auto tensor : outputs_) {
|
||||
if (tensor != nullptr) {
|
||||
delete tensor;
|
||||
}
|
||||
}
|
||||
}
|
||||
virtual ~ExecutionPlan();
|
||||
|
||||
std::vector<std::shared_ptr<abstract::ExecutionFlow>> GetExecutionFLows() override { return execution_flows_; }
|
||||
|
||||
|
@ -55,7 +35,7 @@ class ExecutionPlan : public abstract::ExecutionPlan {
|
|||
execution_flows_ = execution_flows;
|
||||
}
|
||||
|
||||
void AddExecutionFlow(std::shared_ptr<ExecutionFlow> execution_flow) override {
|
||||
void AddExecutionFlow(std::shared_ptr<abstract::ExecutionFlow> execution_flow) override {
|
||||
execution_flows_.emplace_back(execution_flow);
|
||||
}
|
||||
|
||||
|
@ -71,25 +51,42 @@ class ExecutionPlan : public abstract::ExecutionPlan {
|
|||
|
||||
void SetOutputs(const std::vector<abstract::Tensor *> &outputs) override { outputs_ = outputs; }
|
||||
|
||||
void SetInputsMap(std::unordered_map<Tensor *, Tensor *> *input_isolate_map) {
|
||||
std::shared_ptr<abstract::Context> GetContext() override { return context_; }
|
||||
|
||||
void SetContext(std::shared_ptr<abstract::Context> context) override { context_ = context; }
|
||||
|
||||
const abstract::KernelCallBack &GetKernelBeforeCallBack() override { return before_; }
|
||||
|
||||
void SetKernelBeforeCallBack(const abstract::KernelCallBack &callback) override { before_ = callback; }
|
||||
|
||||
const abstract::KernelCallBack &GetKernelAfterCallBack() override { return after_; }
|
||||
|
||||
void SetKernelAfterCallBack(const abstract::KernelCallBack &callback) override { after_ = callback; }
|
||||
|
||||
void SetInputsMap(std::unordered_map<abstract::Tensor *, abstract::Tensor *> *input_isolate_map) {
|
||||
input_isolate_map_ = input_isolate_map;
|
||||
}
|
||||
|
||||
std::unordered_map<Tensor *, Tensor *> *GetInputMap() { return input_isolate_map_; }
|
||||
std::unordered_map<abstract::Tensor *, abstract::Tensor *> *GetInputMap() { return input_isolate_map_; }
|
||||
|
||||
void SetOutputsMap(std::unordered_map<Tensor *, Tensor *> *output_isolate_map) {
|
||||
void SetOutputsMap(std::unordered_map<abstract::Tensor *, abstract::Tensor *> *output_isolate_map) {
|
||||
output_isolate_map_ = output_isolate_map;
|
||||
}
|
||||
|
||||
std::unordered_map<Tensor *, Tensor *> *GetOutputMap() { return output_isolate_map_; }
|
||||
std::unordered_map<abstract::Tensor *, abstract::Tensor *> *GetOutputMap() { return output_isolate_map_; }
|
||||
|
||||
std::vector<abstract::Kernel *> ToKernelList() override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<abstract::ExecutionFlow>> execution_flows_;
|
||||
FuncGraphPtr func_graph_;
|
||||
std::vector<abstract::Tensor *> inputs_;
|
||||
std::vector<abstract::Tensor *> outputs_;
|
||||
std::unordered_map<Tensor *, Tensor *> *input_isolate_map_ = nullptr;
|
||||
std::unordered_map<Tensor *, Tensor *> *output_isolate_map_ = nullptr;
|
||||
std::shared_ptr<abstract::Context> context_;
|
||||
abstract::KernelCallBack before_;
|
||||
abstract::KernelCallBack after_;
|
||||
std::unordered_map<abstract::Tensor *, abstract::Tensor *> *input_isolate_map_ = nullptr;
|
||||
std::unordered_map<abstract::Tensor *, abstract::Tensor *> *output_isolate_map_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::infer
|
||||
|
||||
|
|
|
@ -18,7 +18,9 @@
|
|||
#include <algorithm>
|
||||
|
||||
#include "extendrt/graph_compiler/default_graph_compiler.h"
|
||||
#include "extendrt/graph_compiler/factory.h"
|
||||
|
||||
#include "extendrt/mock/lite_runtime/converters.h"
|
||||
#include "backend/graph_compiler/graph_partition.h"
|
||||
#include "backend/graph_compiler/segment_runner.h"
|
||||
#include "common/log.h"
|
||||
|
@ -32,6 +34,8 @@ static constexpr auto ms_infer_backend_name = "mindspore_lite_backend";
|
|||
std::shared_ptr<abstract::ExecutionPlan> DefaultGraphCompiler::Compile(FuncGraphPtr graph) {
|
||||
MS_LOG(INFO) << "DefaultGraphCompiler::Compile";
|
||||
|
||||
inner_context_ = ContextUtils::Convert(context_.get());
|
||||
|
||||
MS_LOG(DEBUG) << "DefaultGraphCompiler::Partition Partition FunctionGraph Begin";
|
||||
auto graph_segments = Partition(graph);
|
||||
if (graph_segments.empty()) {
|
||||
|
@ -102,6 +106,7 @@ std::shared_ptr<abstract::ExecutionPlan> DefaultGraphCompiler::Schedule(
|
|||
return nullptr;
|
||||
}
|
||||
execution_plan->SetOutputs(graph_output_tensor);
|
||||
execution_plan->SetContext(inner_context_);
|
||||
|
||||
for (auto graph_segment : graph_segments) {
|
||||
FuncGraphPtr fg = nullptr;
|
||||
|
@ -115,6 +120,7 @@ std::shared_ptr<abstract::ExecutionPlan> DefaultGraphCompiler::Schedule(
|
|||
delete output_isolate_map;
|
||||
return nullptr;
|
||||
}
|
||||
execution_flow->SetContext(inner_context_);
|
||||
|
||||
for (auto i = 0; i < execution_flow->GetInputs().size(); i++) {
|
||||
auto input_tensor = execution_flow->GetInputs()[i];
|
||||
|
@ -231,4 +237,10 @@ std::shared_ptr<abstract::ExecutionFlow> DefaultGraphCompiler::Schedule(const Gr
|
|||
// implementation by hangangqiang
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static std::shared_ptr<InferSession> DefaultGraphCompilerCreator(const std::shared_ptr<Context> &ctx) {
|
||||
auto graph_compiler = std::make_shared<DefaultGraphCompiler>(ctx);
|
||||
return graph_compiler;
|
||||
}
|
||||
REG_GRAPH_COMPILER(kDefaultCompiler, DefaultGraphCompilerCreator);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,11 +20,14 @@
|
|||
#include <vector>
|
||||
|
||||
#include "infer/graph_compiler.h"
|
||||
#include "infer/context.h"
|
||||
|
||||
namespace mindspore {
|
||||
class DefaultGraphCompiler : public mindspore::infer::abstract::GraphCompiler {
|
||||
public:
|
||||
DefaultGraphCompiler() {}
|
||||
explicit DefaultGraphCompiler(const std::shared_ptr<Context> &context) : context_(context) {
|
||||
inner_context_ = nullptr;
|
||||
}
|
||||
virtual ~DefaultGraphCompiler() = default;
|
||||
|
||||
std::shared_ptr<abstract::ExecutionPlan> Compile(FuncGraphPtr graph) override;
|
||||
|
@ -48,6 +51,8 @@ class DefaultGraphCompiler : public mindspore::infer::abstract::GraphCompiler {
|
|||
|
||||
private:
|
||||
mindspore::HashMap<AnfNodePtr, infer::abstract::Tensor *> anf_tensor_map_;
|
||||
const std::shared_ptr<Context> &context_;
|
||||
std::shared_ptr<mindspore::infer::abstract::Context> inner_context_;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -27,11 +27,12 @@ void GraphCompilerRegistry::RegCompiler(const mindspore::GraphCompilerType &type
|
|||
graph_compiler_map_[type] = creator;
|
||||
}
|
||||
|
||||
std::shared_ptr<infer::GraphCompiler> GraphCompilerRegistry::GetCompiler(const mindspore::GraphCompilerType &type) {
|
||||
std::shared_ptr<infer::abstract::GraphCompiler> GraphCompilerRegistry::GetCompiler(
|
||||
const mindspore::GraphCompilerType &type, const std::shared_ptr<Context> &context) {
|
||||
auto it = graph_compiler_map_.find(type);
|
||||
if (it == graph_compiler_map_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return it->second();
|
||||
return it->second(context);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,11 +20,12 @@
|
|||
#include <memory>
|
||||
|
||||
#include "extendrt/graph_compiler/type.h"
|
||||
#include "include/api/context.h"
|
||||
#include "infer/graph_compiler.h"
|
||||
|
||||
namespace mindspore {
|
||||
using GraphCompiler = infer::abstract::GraphCompiler;
|
||||
using GraphCompilerRegFunc = std::function<std::shared_ptr<GraphCompiler>()>;
|
||||
using GraphCompilerRegFunc =
|
||||
std::function<std::shared_ptr<infer::abstract::GraphCompiler>(const std::shared_ptr<Context> &)>;
|
||||
|
||||
class GraphCompilerRegistry {
|
||||
public:
|
||||
|
@ -35,7 +36,8 @@ class GraphCompilerRegistry {
|
|||
|
||||
void RegCompiler(const mindspore::GraphCompilerType &graph_compiler_type, const GraphCompilerRegFunc &creator);
|
||||
|
||||
std::shared_ptr<GraphCompiler> GetCompiler(const mindspore::GraphCompilerType &type);
|
||||
std::shared_ptr<infer::abstract::GraphCompiler> GetCompiler(const mindspore::GraphCompilerType &type,
|
||||
const std::shared_ptr<Context> &context);
|
||||
|
||||
private:
|
||||
mindspore::HashMap<mindspore::GraphCompilerType, GraphCompilerRegFunc> graph_compiler_map_;
|
||||
|
@ -44,7 +46,7 @@ class GraphCompilerRegistry {
|
|||
class GraphCompilerRegistrar {
|
||||
public:
|
||||
GraphCompilerRegistrar(const mindspore::GraphCompilerType &graph_compiler_type, const GraphCompilerRegFunc &creator) {
|
||||
GraphCompilerRegistry::GetInstance().GetGraphCompiler(graph_compiler_type, creator);
|
||||
GraphCompilerRegistry::GetInstance().RegCompiler(graph_compiler_type, creator);
|
||||
}
|
||||
~GraphCompilerRegistrar() = default;
|
||||
};
|
||||
|
|
|
@ -16,9 +16,6 @@
|
|||
#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_COMIPLER_TYPE_H_
|
||||
#define MINDSPORE_LITE_EXTENDRT_GRAPH_COMIPLER_TYPE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
enum GraphCompilerType { kDefaultCompiler = 0, kSingleOpSession, kLiteInferSession, kDelegateSession, kNoneCompiler };
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "extendrt/graph_executor/factory.h"
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
GraphExecutorRegistry &GraphExecutorRegistry::GetInstance() {
|
||||
static GraphExecutorRegistry instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void GraphExecutorRegistry::RegExecutor(const mindspore::GraphExecutorType &type, const GraphExecutorRegFunc &creator) {
|
||||
graph_executor_map_[type] = creator;
|
||||
}
|
||||
|
||||
std::shared_ptr<infer::GraphExecutor> GraphExecutorRegistry::GetExecutor(const mindspore::GraphExecutorType &type) {
|
||||
auto it = graph_executor_map_.find(type);
|
||||
if (it == graph_executor_map_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return it->second();
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_EXECUTOR_FACTORY_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_EXECUTOR_FACTORY_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "extendrt/graph_executor/type.h"
|
||||
#include "infer/executor.h"
|
||||
#include "infer/execution_plan.h"
|
||||
|
||||
namespace mindspore {
|
||||
using GraphExecutorRegFunc = std::function<std::shared_ptr<infer::abstract::Executor>(
|
||||
const std::string &name, std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan)>;
|
||||
|
||||
class GraphExecutorRegistry {
|
||||
public:
|
||||
GraphExecutorRegistry() = default;
|
||||
virtual ~GraphExecutorRegistry() = default;
|
||||
|
||||
static GraphExecutorRegistry &GetInstance();
|
||||
|
||||
void RegExecutor(const GraphExecutorType &type, const GraphExecutorRegFunc &creator);
|
||||
|
||||
std::shared_ptr<infer::abstract::Executor> GetExecutor(
|
||||
const mindspore::GraphExecutorType &type, const std::string &name,
|
||||
std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan);
|
||||
|
||||
private:
|
||||
mindspore::HashMap<GraphExecutorType, GraphExecutorRegFunc> graph_executor_map_;
|
||||
};
|
||||
|
||||
class GraphExecutorRegistrar {
|
||||
public:
|
||||
GraphExecutorRegistrar(const mindspore::GraphExecutorType &type, const GraphExecutorRegFunc &creator) {
|
||||
GraphExecutorRegistry::GetInstance().RegExecutor(type, creator);
|
||||
}
|
||||
~GraphExecutorRegistrar() = default;
|
||||
};
|
||||
|
||||
#define REG_GRAPH_EXECUTOR(type, creator) static GraphExecutorRegistrar g_##type##GraphExecutor(type, creator);
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_EXECUTOR_FACTORY_H_
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "extendrt/flow_executor.h"
|
||||
#include "extendrt/graph_executor/flow_executor.h"
|
||||
#include "extendrt/execution_plan.h"
|
||||
#include "litert/mindrt_executor.h"
|
||||
|
|
@ -35,7 +35,7 @@ class FlowExecutor : public mindspore::infer::abstract::Executor {
|
|||
|
||||
const std::string &Name() override { return name_; }
|
||||
|
||||
Status Prepare(std::shared_ptr<abstract::ExecutionFlow> execution_flow) override;
|
||||
Status Prepare() override;
|
||||
|
||||
Status Execute() override;
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "extendrt/graph_executor/mindrt_graph_executor.h"
|
||||
|
||||
#include "src/common/log.h"
|
||||
#include "extendrt/graph_executor/factory.h"
|
||||
#include "litert/mindrt_executor.h"
|
||||
#include "extendrt/execution_plan.h"
|
||||
|
||||
namespace mindspore {
|
||||
MindRTGraphExecutor::MindRTGraphExecutor() {
|
||||
name_ = "";
|
||||
execution_plan_ = nullptr;
|
||||
}
|
||||
|
||||
MindRTGraphExecutor::MindRTGraphExecutor(const std::string &name,
|
||||
std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan) {
|
||||
name_ = name;
|
||||
execution_plan_ = execution_plan;
|
||||
auto infer_execution_plan = std::dynamic_pointer_cast<infer::ExecutionPlan>(execution_plan_);
|
||||
if (infer_execution_plan == nullptr) {
|
||||
MS_LOG(ERROR) << "MindRTGraphExecutor::MindRTGraphExecutor Not Supported execution plan is passed";
|
||||
} else {
|
||||
mindrt_executor_ = std::make_shared<mindspore::lite::MindrtExecutor>(infer_execution_plan->GetInputMap(),
|
||||
infer_execution_plan->GetOutputMap());
|
||||
}
|
||||
}
|
||||
|
||||
Status MindRTGraphExecutor::Prepare() {
|
||||
if (mindrt_executor_ == nullptr) {
|
||||
MS_LOG(ERROR) << "FlowExecutor::Prepare executor is nullptr";
|
||||
return kLiteError;
|
||||
}
|
||||
if (execution_plan_ == nullptr) {
|
||||
MS_LOG(ERROR) << "FlowExecutor::Prepare execution plan is nullptr";
|
||||
return kLiteError;
|
||||
}
|
||||
return mindrt_executor_->Prepare(execution_plan_->ToKernelList(), execution_plan_->GetInputs(),
|
||||
execution_plan_->GetOutputs(), execution_plan_->GetContext().get());
|
||||
}
|
||||
|
||||
Status MindRTGraphExecutor::Execute() {
|
||||
if (mindrt_executor_ == nullptr) {
|
||||
MS_LOG(ERROR) << "FlowExecutor::Execute executor is nullptr";
|
||||
return kLiteError;
|
||||
}
|
||||
if (execution_plan_ == nullptr) {
|
||||
MS_LOG(ERROR) << "FlowExecutor::Execute execution plan is nullptr";
|
||||
return kLiteError;
|
||||
}
|
||||
return mindrt_executor_->Run(execution_plan_->GetInputs(), execution_plan_->GetOutputs(),
|
||||
execution_plan_->ToKernelList(), execution_plan_->GetKernelBeforeCallBack(),
|
||||
execution_plan_->GetKernelAfterCallBack());
|
||||
}
|
||||
|
||||
int MindRTGraphExecutor::Resize(const std::vector<infer::abstract::Tensor *> &inputs,
|
||||
const std::vector<std::vector<int>> &dims) {
|
||||
if (mindrt_executor_ == nullptr) {
|
||||
MS_LOG(ERROR) << "FlowExecutor::Resize executor is nullptr";
|
||||
return kLiteError;
|
||||
}
|
||||
return mindrt_executor_->Resize(inputs, dims);
|
||||
}
|
||||
|
||||
static std::shared_ptr<infer::abstract::Executor> MindRTGraphExecutorCreator(
|
||||
const std::string &name, std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan) {
|
||||
auto graph_executor = std::make_shared<MindRTGraphExecutor>(name, execution_plan);
|
||||
return graph_executor;
|
||||
}
|
||||
REG_GRAPH_EXECUTOR(kMindRTExecutor, MindRTGraphExecutorCreator);
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_RUNTIME_DEFAULT_GRAPH_RUNTIME_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_RUNTIME_DEFAULT_GRAPH_RUNTIME_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "infer/executor.h"
|
||||
#include "infer/execution_plan.h"
|
||||
#include "litert/executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
class MindRTGraphExecutor : public mindspore::infer::abstract::Executor {
|
||||
public:
|
||||
MindRTGraphExecutor();
|
||||
explicit MindRTGraphExecutor(const std::string &name, std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan);
|
||||
virtual ~MindRTGraphExecutor() = default;
|
||||
|
||||
const std::string &Name() override { return name_; }
|
||||
|
||||
Status Prepare() override;
|
||||
|
||||
Status Execute() override;
|
||||
|
||||
int Resize(const std::vector<infer::abstract::Tensor *> &inputs, const std::vector<std::vector<int>> &dims) override;
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
std::shared_ptr<mindspore::lite::Executor> mindrt_executor_;
|
||||
std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_RUNTIME_DEFAULT_GRAPH_RUNTIME_H_
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "extendrt/graph_executor/plan_executor.h"
|
||||
#include "extendrt/execution_plan.h"
|
||||
#include "litert/mindrt_executor.h"
|
||||
|
||||
namespace mindspore::infer {
|
||||
PlanExecutor::PlanExecutor() { PlanExecutor("PlanExecutor"); }
|
||||
|
||||
PlanExecutor::PlanExecutor(const std::string &name, std::shared_ptr<abstract::ExecutionPlan> execution_plan) {
|
||||
name_ = name;
|
||||
execution_plan_ = execution_plan;
|
||||
auto infer_execution_plan = std::dynamic_pointer_cast<infer::ExecutionPlan>(execution_plan_);
|
||||
if (infer_execution_plan == nullptr) {
|
||||
MS_LOG(ERROR) << "FlowExecutor::FlowExecutor Not Supported execution plan is passed";
|
||||
} else {
|
||||
executor_ = std::make_shared<mindspore::lite::MindrtExecutor>(infer_execution_plan->GetInputMap(),
|
||||
infer_execution_plan->GetOutputMap());
|
||||
}
|
||||
}
|
||||
|
||||
Status PlanExecutor::Prepare() {
|
||||
if (executor_ == nullptr) {
|
||||
MS_LOG(ERROR) << "FlowExecutor::Prepare executor is nullptr";
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
if (execution_plan_ == nullptr) {
|
||||
MS_LOG(ERROR) << "FlowExecutor::Prepare execution plan is nullptr";
|
||||
return kLiteError;
|
||||
}
|
||||
return executor_->Prepare(execution_plan_->ToKernelList(), execution_plan_->GetInputs(),
|
||||
execution_plan_->GetOutputs(), execution_plan_->GetContext());
|
||||
}
|
||||
|
||||
Status PlanExecutor::Execute() {
|
||||
if (executor_ == nullptr) {
|
||||
MS_LOG(ERROR) << "FlowExecutor::Execute executor is nullptr";
|
||||
return kLiteError;
|
||||
}
|
||||
if (execution_plan_ == nullptr) {
|
||||
MS_LOG(ERROR) << "FlowExecutor::Execute execution plan is nullptr";
|
||||
return kLiteError;
|
||||
}
|
||||
return executor_->Run(execution_plan_->GetInputs(), execution_plan_->GetOutputs(), execution_plan_->ToKernelList(),
|
||||
execution_plan_->GetKernelBeforeCallBack(), execution_plan_->GetKernelAfterCallBack());
|
||||
}
|
||||
|
||||
int PlanExecutor::Resize(const std::vector<abstract::Tensor *> &inputs, const std::vector<std::vector<int>> &dims) {
|
||||
if (executor_ == nullptr) {
|
||||
MS_LOG(ERROR) << "FlowExecutor::Resize executor is nullptr";
|
||||
return kLiteError;
|
||||
}
|
||||
return executor_->Resize(inputs, dims);
|
||||
}
|
||||
} // namespace mindspore::infer
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_PLAN_EXECUTOR_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_PLAN_EXECUTOR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "infer/executor.h"
|
||||
#include "infer/execution_plan.h"
|
||||
#include "litert/executor.h"
|
||||
|
||||
namespace mindspore::infer {
|
||||
class PlanExecutor : public mindspore::infer::abstract::Executor {
|
||||
public:
|
||||
PlanExecutor();
|
||||
// explicit FlowExecutor(const std::string &name);
|
||||
explicit PlanExecutor(const std::string &name);
|
||||
virtual ~PlanExecutor() = default;
|
||||
|
||||
const std::string &Name() override { return name_; }
|
||||
|
||||
Status Prepare() override;
|
||||
|
||||
Status Execute() override;
|
||||
|
||||
int Resize(const std::vector<abstract::Tensor *> &inputs, const std::vector<std::vector<int>> &dims) override;
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
std::shared_ptr<mindspore::lite::Executor> executor_;
|
||||
std::shared_ptr<abstract::ExecutionPlan> execution_plan_;
|
||||
};
|
||||
} // namespace mindspore::infer
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_PLAN_EXECUTOR_H_
|
|
@ -0,0 +1,22 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_EXECUTOR_TYPE_H_
|
||||
#define MINDSPORE_LITE_EXTENDRT_GRAPH_EXECUTOR_TYPE_H_
|
||||
|
||||
namespace mindspore {
|
||||
enum GraphExecutorType { kDefaultExecutor = 0, kMindRTExecutor, kNoneExcutor };
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_EXTENDRT_GRAPH_EXECUTOR_TYPE_H_
|
|
@ -15,8 +15,9 @@
|
|||
*/
|
||||
#include "extendrt/graph_runtime/default_graph_runtime.h"
|
||||
|
||||
#include "extendrt/flow_executor.h"
|
||||
#include "extendrt/graph_executor/plan_executor.h"
|
||||
#include "src/common/log.h"
|
||||
#include "extendrt/graph_runtime/factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
using ExecutionPlan = mindspore::infer::abstract::ExecutionPlan;
|
||||
|
@ -30,20 +31,20 @@ Status DefaultGraphRuntime::Prepare(std::shared_ptr<ExecutionPlan> execution_pla
|
|||
}
|
||||
execution_plan_ = execution_plan;
|
||||
|
||||
for (auto execution_flow : execution_plan->GetExecutionFLows()) {
|
||||
auto executor = SelectExecutor(execution_flow);
|
||||
if (executor == nullptr) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Prepare Select Executor is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
MS_LOG(DEBUG) << "DefaultGraphRuntime::Prepare Prepare Execution Plan Begin of Executor " << executor->Name();
|
||||
auto status = executor->Prepare(execution_flow);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Prepare Prepare Execution Plan Failed in Executor " << executor->Name();
|
||||
return kLiteError;
|
||||
}
|
||||
MS_LOG(DEBUG) << "DefaultGraphRuntime::Prepare Prepare Execution Plan End";
|
||||
auto executor = SelectExecutor();
|
||||
if (executor == nullptr) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Prepare Select Executor is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "DefaultGraphRuntime::Prepare Prepare Execution Plan Begin of Executor " << executor->Name();
|
||||
auto status = executor->Prepare(nullptr);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Prepare Prepare Execution Plan Failed in Executor " << executor->Name();
|
||||
return kLiteError;
|
||||
}
|
||||
MS_LOG(DEBUG) << "DefaultGraphRuntime::Prepare Prepare Execution Plan End";
|
||||
|
||||
MS_LOG(INFO) << "AbstractRuntime::Prepare End";
|
||||
return kSuccess;
|
||||
}
|
||||
|
@ -56,27 +57,27 @@ Status DefaultGraphRuntime::Execute() {
|
|||
return kLiteNullptr;
|
||||
}
|
||||
|
||||
for (auto execution_flow : execution_plan_->GetExecutionFLows()) {
|
||||
auto executor = SelectExecutor(execution_flow);
|
||||
if (executor == nullptr) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Select Executor is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Execution Plan Begin of Executor " << executor->Name();
|
||||
auto status = executor->Execute();
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Execution Plan Failed in Executor " << executor->Name();
|
||||
return kLiteError;
|
||||
}
|
||||
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Prepare Execution Plan End";
|
||||
auto executor = SelectExecutor();
|
||||
if (executor == nullptr) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Select Executor is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Execute Execution Plan Begin of Executor " << executor->Name();
|
||||
auto status = executor->Execute();
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Execute Execution Plan Failed in Executor " << executor->Name();
|
||||
return kLiteError;
|
||||
}
|
||||
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Execute Execution Plan End";
|
||||
|
||||
MS_LOG(INFO) << "DefaultGraphRuntime::Execute End";
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status DefaultGraphRuntime::Execute(const std::vector<abstract::Tensor *> &inputs,
|
||||
const std::vector<abstract::Tensor *> &outputs, abstract::KernelCallBack before,
|
||||
abstract::KernelCallBack after) {
|
||||
Status DefaultGraphRuntime::Execute(const std::vector<infer::abstract::Tensor *> &inputs,
|
||||
const std::vector<infer::abstract::Tensor *> &outputs,
|
||||
infer::abstract::KernelCallBack before, infer::abstract::KernelCallBack after) {
|
||||
MS_LOG(INFO) << "DefaultGraphRuntime::Execute Begin";
|
||||
|
||||
if (execution_plan_ == nullptr) {
|
||||
|
@ -84,37 +85,65 @@ Status DefaultGraphRuntime::Execute(const std::vector<abstract::Tensor *> &input
|
|||
return kLiteNullptr;
|
||||
}
|
||||
|
||||
for (auto &execution_flow : execution_plan_->GetExecutionFLows()) {
|
||||
auto executor = SelectExecutor(execution_flow);
|
||||
if (executor == nullptr) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Select Executor is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Execution Plan Begin of Executor " << executor->Name();
|
||||
execution_flow->SetInputs(inputs);
|
||||
execution_flow->SetOutputs(outputs);
|
||||
execution_flow->SetKernelBeforeCallBack(before);
|
||||
execution_flow->SetKernelAfterCallBack(after);
|
||||
auto status = executor->Execute();
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Execution Plan Failed in Executor " << executor->Name();
|
||||
return kLiteError;
|
||||
}
|
||||
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Prepare Execution Plan End";
|
||||
auto executor = SelectExecutor();
|
||||
if (executor == nullptr) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Select Executor is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Execute Execution Plan Begin of Executor " << executor->Name();
|
||||
execution_plan_->SetInputs(inputs);
|
||||
execution_plan_->SetOutputs(outputs);
|
||||
execution_plan_->SetKernelBeforeCallBack(before);
|
||||
execution_plan_->SetKernelAfterCallBack(after);
|
||||
auto status = executor->Execute();
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Execute Execute Execution Plan Failed in Executor " << executor->Name();
|
||||
return kLiteError;
|
||||
}
|
||||
MS_LOG(DEBUG) << "DefaultGraphRuntime::Execute Execute Execution Plan End";
|
||||
|
||||
MS_LOG(INFO) << "DefaultGraphRuntime::Execute End";
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
std::shared_ptr<abstract::Executor> DefaultGraphRuntime::SelectExecutor(
|
||||
const std::shared_ptr<abstract::ExecutionFlow> &execution_flow) {
|
||||
auto it = executor_map_.find(execution_flow);
|
||||
if (it == executor_map_.end()) {
|
||||
// create a new executor for execution flow
|
||||
auto executor = std::make_shared<infer::FlowExecutor>("flow-executor");
|
||||
executor_map_[execution_flow] = executor;
|
||||
return executor;
|
||||
Status DefaultGraphRuntime::Resize(const std::vector<infer::abstract::Tensor *> *inputs,
|
||||
const std::vector<std::vector<int64_t>> &dims) {
|
||||
MS_LOG(INFO) << "DefaultGraphRuntime::Resize Begin";
|
||||
|
||||
if (execution_plan_ == nullptr) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Resize Execution Plan is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
return it->second;
|
||||
|
||||
auto executor = SelectExecutor();
|
||||
if (executor == nullptr) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Resize Select Executor is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "DefaultGraphRuntime::Resize Resize Execution Plan Begin of Executor " << executor->Name();
|
||||
auto status = executor->Resize(inputs, dims);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "DefaultGraphRuntime::Resize Resize Execution Plan Failed in Executor " << executor->Name();
|
||||
return kLiteError;
|
||||
}
|
||||
MS_LOG(DEBUG) << "DefaultGraphRuntime::Resize Resize Execution Plan End";
|
||||
|
||||
MS_LOG(INFO) << "DefaultGraphRuntime::Resize End";
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
std::shared_ptr<infer::abstract::Executor> DefaultGraphRuntime::SelectExecutor() {
|
||||
if (default_executor_ == nullptr) {
|
||||
default_executor_ = std::make_shared<infer::PlanExecutor>("plan-executor");
|
||||
}
|
||||
return default_executor_;
|
||||
}
|
||||
|
||||
static std::shared_ptr<InferSession> DefaultGraphRuntimeCreator() {
|
||||
auto graph_runtime = std::make_shared<DefaultGraphRuntime>();
|
||||
return graph_runtime;
|
||||
}
|
||||
REG_GRAPH_RUNTIME(kDefaultRuntime, DefaultGraphRuntimeCreator);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,19 +27,24 @@ class DefaultGraphRuntime : public mindspore::infer::abstract::GraphRuntime {
|
|||
DefaultGraphRuntime() = default;
|
||||
virtual ~DefaultGraphRuntime() = default;
|
||||
|
||||
Status Prepare(std::shared_ptr<abstract::ExecutionPlan> execution_plan) override;
|
||||
Status Prepare(std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan) override;
|
||||
|
||||
Status Execute() override;
|
||||
|
||||
Status Execute(const std::vector<abstract::Tensor *> &inputs, const std::vector<abstract::Tensor *> &outputs,
|
||||
abstract::KernelCallBack before = nullptr, abstract::KernelCallBack after = nullptr) override;
|
||||
Status Execute(const std::vector<infer::abstract::Tensor *> &inputs,
|
||||
const std::vector<infer::abstract::Tensor *> &outputs,
|
||||
infer::abstract::KernelCallBack before = nullptr,
|
||||
infer::abstract::KernelCallBack after = nullptr) override;
|
||||
|
||||
Status Resize(const std::vector<infer::abstract::Tensor *> *inputs,
|
||||
const std::vector<std::vector<int64_t>> &dims) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<abstract::Executor> SelectExecutor(const std::shared_ptr<abstract::ExecutionFlow> &execution_flow);
|
||||
std::shared_ptr<infer::abstract::Executor> SelectExecutor();
|
||||
|
||||
private:
|
||||
std::shared_ptr<abstract::ExecutionPlan> execution_plan_ = nullptr;
|
||||
mindspore::HashMap<std::shared_ptr<abstract::ExecutionFlow>, std::shared_ptr<abstract::Executor>> executor_map_;
|
||||
std::shared_ptr<infer::abstract::ExecutionPlan> execution_plan_ = nullptr;
|
||||
std::shared_ptr<infer::abstract::Executor> default_executor_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -18,16 +18,17 @@
|
|||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
GraphRuntimRegistry &GraphRuntimRegistry::GetInstance() {
|
||||
static GraphRuntimRegistry instance;
|
||||
GraphRuntimeRegistry &GraphRuntimeRegistry::GetInstance() {
|
||||
static GraphRuntimeRegistry instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void GraphRuntimRegistry::RegRuntime(const mindspore::GraphRuntimeType &type, const GraphRuntimeRegFunc &creator) {
|
||||
void GraphRuntimeRegistry::RegRuntime(const mindspore::GraphRuntimeType &type, const GraphRuntimeRegFunc &creator) {
|
||||
graph_runtime_map_[type] = creator;
|
||||
}
|
||||
|
||||
std::shared_ptr<infer::GraphRuntime> GraphRuntimRegistry::GetRuntime(const mindspore::GraphRuntimeType &type) {
|
||||
std::shared_ptr<infer::abstract::GraphRuntime> GraphRuntimeRegistry::GetRuntime(
|
||||
const mindspore::GraphRuntimeType &type) {
|
||||
auto it = graph_runtime_map_.find(type);
|
||||
if (it == graph_runtime_map_.end()) {
|
||||
return nullptr;
|
||||
|
|
|
@ -23,19 +23,18 @@
|
|||
#include "infer/graph_runtime.h"
|
||||
|
||||
namespace mindspore {
|
||||
using GraphRuntime = infer::abstract::GraphRuntime;
|
||||
using GraphRuntimeRegFunc = std::function<std::shared_ptr<GraphRuntime>()>;
|
||||
using GraphRuntimeRegFunc = std::function<std::shared_ptr<infer::abstract::GraphRuntime>()>;
|
||||
|
||||
class GraphRuntimRegistry {
|
||||
class GraphRuntimeRegistry {
|
||||
public:
|
||||
GraphRuntimRegistry() = default;
|
||||
virtual ~GraphRuntimRegistry() = default;
|
||||
GraphRuntimeRegistry() = default;
|
||||
virtual ~GraphRuntimeRegistry() = default;
|
||||
|
||||
static GraphRuntimRegistry &GetInstance();
|
||||
static GraphRuntimeRegistry &GetInstance();
|
||||
|
||||
void RegRuntime(const GraphRuntimeType &type, const GraphRuntimeRegFunc &creator);
|
||||
|
||||
std::shared_ptr<GraphRuntime> GetRuntime(const mindspore::GraphRuntimeType &type);
|
||||
std::shared_ptr<infer::abstract::GraphRuntime> GetRuntime(const mindspore::GraphRuntimeType &type);
|
||||
|
||||
private:
|
||||
mindspore::HashMap<GraphRuntimeType, GraphRuntimeRegFunc> graph_runtime_map_;
|
||||
|
@ -44,7 +43,7 @@ class GraphRuntimRegistry {
|
|||
class GraphRuntimeRegistrar {
|
||||
public:
|
||||
GraphRuntimeRegistrar(const mindspore::GraphRuntimeType &type, const GraphRuntimeRegFunc &creator) {
|
||||
GraphRuntimRegistry::GetInstance().RegRuntime(type, creator);
|
||||
GraphRuntimeRegistry::GetInstance().RegRuntime(type, creator);
|
||||
}
|
||||
~GraphRuntimeRegistrar() = default;
|
||||
};
|
||||
|
|
|
@ -16,9 +16,6 @@
|
|||
#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_RUNTIME_TYPE_H_
|
||||
#define MINDSPORE_LITE_EXTENDRT_GRAPH_RUNTIME_TYPE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
enum GraphRuntimeType { kDefaultRuntime = 0, kSingleOpSession, kLiteInferSession, kDelegateSession, kNoneRuntime };
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "extendrt/session/factory.h"
|
||||
#include "extendrt/graph_compiler/factory.h"
|
||||
#include "extendrt/graph_runtime/factory.h"
|
||||
#include "extendrt/utils/tensor_utils.h"
|
||||
#include "backend/graph_compiler/graph_partition.h"
|
||||
|
||||
#include "litert/cxx_api/tensor/tensor_impl.h"
|
||||
|
@ -32,18 +33,18 @@ static const std::vector<PrimitivePtr> ms_infer_cut_list = {prim::kPrimReturn,
|
|||
prim::kPrimBpropCut, prim::kPrimSwitchLayer};
|
||||
Status DefaultInferSession::Init(const std::shared_ptr<Context> &context) {
|
||||
MS_LOG(INFO) << "DefaultInferSession::Init";
|
||||
context_ = context;
|
||||
// context_ = context;
|
||||
|
||||
// Set MSContext::GetInstance param?
|
||||
|
||||
// init compiler and runtime according to context
|
||||
compiler_ = GraphCompilerRegistry::GetInstance()->GetCompiler(kDefaultCompiler);
|
||||
compiler_ = GraphCompilerRegistry::GetInstance().GetCompiler(kDefaultCompiler, context_);
|
||||
if (compiler_ == nullptr) {
|
||||
MS_LOG(ERROR) << "DefaultInferSession::Init Get Compiler is nullptr";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
|
||||
runtime_ = GraphRuntimRegistry::GetInstance()->GetRuntime(kDefaultRuntime);
|
||||
runtime_ = GraphRuntimeRegistry::GetInstance().GetRuntime(kDefaultRuntime);
|
||||
if (runtime_ == nullptr) {
|
||||
MS_LOG(ERROR) << "DefaultInferSession::Init Get Runtime is nullptr";
|
||||
return kLiteNullptr;
|
||||
|
@ -68,7 +69,7 @@ Status DefaultInferSession::CompileGraph(FuncGraphPtr graph, const void *data, s
|
|||
MS_LOG(DEBUG) << "DefaultInferSession::CompileGraph Compile Graph End";
|
||||
|
||||
MS_LOG(DEBUG) << "DefaultInferSession::CompileGraph Prepare ExecutionPlan Begin";
|
||||
auto runtime = this->GetRuntime();
|
||||
auto runtime = this->GetGraphRuntime();
|
||||
if (runtime == nullptr) {
|
||||
MS_LOG(ERROR) << "DefaultInferSession::CompileGraph Runtime in Infer Session is null";
|
||||
return kLiteNullptr;
|
||||
|
@ -86,7 +87,7 @@ Status DefaultInferSession::CompileGraph(FuncGraphPtr graph, const void *data, s
|
|||
Status DefaultInferSession::RunGraph(const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs,
|
||||
const MSKernelCallBack &before, const MSKernelCallBack &after) {
|
||||
MS_LOG(DEBUG) << "DefaultInferSession::RunGraph Execute ExecutionPlan Begin";
|
||||
auto runtime = this->GetRuntime();
|
||||
auto runtime = this->GetGraphRuntime();
|
||||
if (runtime_ == nullptr) {
|
||||
MS_LOG(ERROR) << "DefaultInferSession::RunGraph Runtime in Infer Session is null";
|
||||
return kLiteNullptr;
|
||||
|
@ -101,7 +102,7 @@ Status DefaultInferSession::RunGraph(const std::vector<tensor::Tensor> &inputs,
|
|||
MS_LOG(ERROR) << "DefaultInferSession::RunGraph Copy Data Pointer to input tensors failed";
|
||||
return status;
|
||||
}
|
||||
status = CopyDataToInnerTensors(outputs, inner_outputs);
|
||||
status = CopyDataToInnerTensors(*outputs, inner_outputs);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "DefaultInferSession::RunGraph Copy Data Pointer to output tensors failed";
|
||||
return status;
|
||||
|
@ -124,11 +125,16 @@ Status DefaultInferSession::RunGraph(const std::vector<tensor::Tensor> &inputs,
|
|||
return RunGraph(inputs, outputs, nullptr, nullptr);
|
||||
}
|
||||
|
||||
Status DefaultInferSession::Resize(const std::vector<tensor::Tensor> &inputs,
|
||||
const std::vector<std::vector<int64_t>> &dims) {
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
std::vector<MutableTensorImplPtr> DefaultInferSession::GetOutputs() {
|
||||
auto runtime = this->GetRuntime();
|
||||
auto runtime = this->GetGraphRuntime();
|
||||
if (runtime_ == nullptr) {
|
||||
MS_LOG(ERROR) << "DefaultInferSession::GetOutputs Runtime in Infer Session is null";
|
||||
return kLiteNullptr;
|
||||
return std::vector<MutableTensorImplPtr>{};
|
||||
}
|
||||
auto lite_outputs = runtime->GetOutputs();
|
||||
MS_LOG(DEBUG) << "DefaultInferSession::GetOutputs end";
|
||||
|
@ -136,10 +142,10 @@ std::vector<MutableTensorImplPtr> DefaultInferSession::GetOutputs() {
|
|||
}
|
||||
|
||||
std::vector<MutableTensorImplPtr> DefaultInferSession::GetInputs() {
|
||||
auto runtime = this->GetRuntime();
|
||||
auto runtime = this->GetGraphRuntime();
|
||||
if (runtime_ == nullptr) {
|
||||
MS_LOG(ERROR) << "DefaultInferSession::GetOutputs Runtime in Infer Session is null";
|
||||
return kLiteNullptr;
|
||||
return std::vector<MutableTensorImplPtr>{};
|
||||
}
|
||||
auto lite_inputs = runtime->GetInputs();
|
||||
MS_LOG(DEBUG) << "DefaultInferSession::GetOutputs end";
|
||||
|
@ -152,7 +158,7 @@ MutableTensorImplPtr DefaultInferSession::GetOutputByTensorName(const std::strin
|
|||
MutableTensorImplPtr DefaultInferSession::GetInputByTensorName(const std::string &name) { return nullptr; }
|
||||
|
||||
Status DefaultInferSession::CopyDataToInnerTensors(const std::vector<tensor::Tensor> &tensors,
|
||||
std::vector<abstract::Tensor *> inner_tensors) {
|
||||
std::vector<infer::abstract::Tensor *> inner_tensors) {
|
||||
if (tensors.size() == inner_tensors.size()) {
|
||||
MS_LOG(EXCEPTION) << "user input size " << tensors.size() << " is not equal to graphp input size "
|
||||
<< inner_tensors.size();
|
||||
|
@ -200,17 +206,17 @@ Status DefaultInferSession::CopyDataToInnerTensors(const std::vector<tensor::Ten
|
|||
return kSuccess;
|
||||
}
|
||||
|
||||
std::vector<MutableTensorImplPtr> &DefaultInferSession::AbstractTensorsToTensorImpls(
|
||||
const std::vector<abstract::Tensor *> &abstract_tensors) {
|
||||
std::vector<std::shared_ptr<LiteTensorImpl>> tensorImpls;
|
||||
std::vector<MutableTensorImplPtr> DefaultInferSession::AbstractTensorsToTensorImpls(
|
||||
const std::vector<infer::abstract::Tensor *> &abstract_tensors) {
|
||||
std::vector<MutableTensorImplPtr> tensorImpls;
|
||||
tensorImpls.reserve(abstract_tensors.size());
|
||||
(void)std::transform(abstract_tensors.begin(), abstract_tensors.end(), std::back_inserter(tensorImpls),
|
||||
[](abstract::Tensor *tensor) { return std::make_shared<LiteTensorImpl>(tensor); });
|
||||
[](infer::abstract::Tensor *tensor) { return std::make_shared<LiteTensorImpl>(tensor); });
|
||||
return tensorImpls;
|
||||
}
|
||||
|
||||
std::vector<mindspore::tensor::Tensor> DefaultInferSession::LiteTensorToTensor(
|
||||
const std::vector<abstract::Tensor *> &abstract_tensors) {
|
||||
const std::vector<infer::abstract::Tensor *> &abstract_tensors) {
|
||||
std::vector<mindspore::tensor::Tensor> tensors;
|
||||
for (auto abstract_tensor : abstract_tensors) {
|
||||
if (abstract_tensor == nullptr) {
|
||||
|
@ -222,11 +228,14 @@ std::vector<mindspore::tensor::Tensor> DefaultInferSession::LiteTensorToTensor(
|
|||
auto data = abstract_tensor->MutableData();
|
||||
auto data_size = abstract_tensor->Size();
|
||||
auto ref_tensor_data =
|
||||
std::make_shared<TensorRefData>(data, abstract_tensor->ElementNum(), data_size, shape.size());
|
||||
mindspore::tensor::Tensor tensor(type_id, shape, ref_tensor_data);
|
||||
std::make_shared<TensorRefData>(data, abstract_tensor->ElementsNum(), data_size, shape.size());
|
||||
std::vector<int64_t> shape64;
|
||||
std::transform(shape.begin(), shape.end(), std::back_inserter(shape64),
|
||||
[](int dim) { return static_cast<int64_t>(dim); });
|
||||
mindspore::tensor::Tensor tensor(type_id, shape64, ref_tensor_data);
|
||||
auto device_address = abstract_tensor->device_data();
|
||||
if (device_address != nullptr) {
|
||||
auto lite_device_address = std::make_shared<LiteDeviceAddress>(device_address, abstract_tensor->DataSize());
|
||||
auto lite_device_address = std::make_shared<LiteDeviceAddress>(device_address, abstract_tensor->Size());
|
||||
tensor.set_device_address(lite_device_address);
|
||||
}
|
||||
tensors.emplace_back(std::move(tensor));
|
||||
|
@ -234,6 +243,34 @@ std::vector<mindspore::tensor::Tensor> DefaultInferSession::LiteTensorToTensor(
|
|||
return tensors;
|
||||
}
|
||||
|
||||
std::vector<int32_t> DefaultInferSession::TruncateShape(const std::vector<int64_t> &shape, enum TypeId type,
|
||||
size_t data_len, bool verify_size) {
|
||||
std::vector<int32_t> empty;
|
||||
if (shape.empty()) {
|
||||
return empty;
|
||||
}
|
||||
std::vector<int32_t> truncated_shape;
|
||||
truncated_shape.resize(shape.size());
|
||||
size_t element_size = lite::DataTypeSize(type);
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
auto dim = shape[i];
|
||||
if (dim < 0 || dim > INT_MAX || (dim != 0 && element_size > INT_MAX / static_cast<size_t>(dim))) {
|
||||
MS_LOG(ERROR) << "Invalid shape!dim: " << dim << ", element_size: " << element_size;
|
||||
return empty;
|
||||
} else {
|
||||
element_size *= static_cast<size_t>(dim);
|
||||
truncated_shape[i] = static_cast<int32_t>(dim);
|
||||
}
|
||||
}
|
||||
if (verify_size) {
|
||||
if (element_size != data_len) {
|
||||
MS_LOG(ERROR) << "Invalid data size!element_size: " << element_size << ", data_len: " << data_len;
|
||||
return empty;
|
||||
}
|
||||
}
|
||||
return truncated_shape;
|
||||
}
|
||||
|
||||
static std::shared_ptr<InferSession> DefaultSessionCreator(const std::shared_ptr<Context> &ctx,
|
||||
const ConfigInfos &config_infos) {
|
||||
auto session = std::make_shared<DefaultInferSession>(ctx);
|
||||
|
|
|
@ -37,6 +37,7 @@ class DefaultInferSession : public InferSession {
|
|||
Status RunGraph(const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs) override;
|
||||
Status RunGraph(const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs,
|
||||
const MSKernelCallBack &before, const MSKernelCallBack &after) override;
|
||||
Status Resize(const std::vector<tensor::Tensor> &inputs, const std::vector<std::vector<int64_t>> &dims) override;
|
||||
std::vector<MutableTensorImplPtr> GetOutputs() override;
|
||||
std::vector<MutableTensorImplPtr> GetInputs() override;
|
||||
std::vector<std::string> GetOutputNames() override;
|
||||
|
@ -45,16 +46,19 @@ class DefaultInferSession : public InferSession {
|
|||
MutableTensorImplPtr GetInputByTensorName(const std::string &name) override;
|
||||
|
||||
protected:
|
||||
virtual std::shared_ptr<infer::abstract::GraphCompiler> GetCompiler() { return compiler_; }
|
||||
virtual std::shared_ptr<infer::abstract::GraphCompiler> GetGraphCompiler() { return compiler_; }
|
||||
|
||||
virtual std::shared_ptr<infer::abstract::GraphRuntime> GetRuntime() { return runtime_; }
|
||||
virtual std::shared_ptr<infer::abstract::GraphRuntime> GetGraphRuntime() { return runtime_; }
|
||||
|
||||
private:
|
||||
Status CopyDataToInnerTensors(const std::vector<tensor::Tensor> &tensors,
|
||||
std::vector<abstract::Tensor *> inner_tensors);
|
||||
std::vector<MutableTensorImplPtr> &AbstractTensorsToTensorImpls(
|
||||
const std::vector<abstract::Tensor *> &abstract_tensors);
|
||||
std::vector<mindspore::tensor::Tensor> LiteTensorToTensor(const std::vector<abstract::Tensor *> &abstract_tensors);
|
||||
std::vector<infer::abstract::Tensor *> inner_tensors);
|
||||
std::vector<MutableTensorImplPtr> AbstractTensorsToTensorImpls(
|
||||
const std::vector<infer::abstract::Tensor *> &abstract_tensors);
|
||||
std::vector<mindspore::tensor::Tensor> LiteTensorToTensor(
|
||||
const std::vector<infer::abstract::Tensor *> &abstract_tensors);
|
||||
std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
|
||||
bool verify_size);
|
||||
|
||||
private:
|
||||
std::shared_ptr<infer::abstract::GraphCompiler> compiler_;
|
||||
|
|
|
@ -67,14 +67,14 @@ class ExecutionFlow : public std::enable_shared_from_this<ExecutionFlow> {
|
|||
/// \brief Get context for the execution flow.
|
||||
///
|
||||
/// \return Context pointer.
|
||||
virtual Context *GetContext() = 0;
|
||||
virtual std::shared_ptr<Context> GetContext() = 0;
|
||||
|
||||
/// \brief Set context of execution run
|
||||
///
|
||||
/// \param[in] context, context for running
|
||||
///
|
||||
/// \return void.
|
||||
virtual void SetContext(Context *context) = 0;
|
||||
virtual void SetContext(std::shared_ptr<Context> context) = 0;
|
||||
|
||||
/// \brief Get callback before kernel execution.
|
||||
///
|
||||
|
@ -99,6 +99,11 @@ class ExecutionFlow : public std::enable_shared_from_this<ExecutionFlow> {
|
|||
///
|
||||
/// \return void.
|
||||
virtual void SetKernelAfterCallBack(const KernelCallBack &callback) = 0;
|
||||
|
||||
/// \brief Construct flow into one fusion Kernel, eg. SubGraphKernel.
|
||||
///
|
||||
/// \return Kernel pointer.
|
||||
virtual Kernel *ConstructFusionKernel() = 0;
|
||||
};
|
||||
} // namespace mindspore::infer::abstract
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "ir/func_graph.h"
|
||||
#include "infer/execution_flow.h"
|
||||
#include "infer/context.h"
|
||||
|
||||
namespace mindspore::infer::abstract {
|
||||
class ExecutionPlan : public std::enable_shared_from_this<ExecutionPlan> {
|
||||
|
@ -81,6 +82,47 @@ class ExecutionPlan : public std::enable_shared_from_this<ExecutionPlan> {
|
|||
///
|
||||
/// \return void.
|
||||
virtual void SetOutputs(const std::vector<Tensor *> &outputs) = 0;
|
||||
|
||||
/// \brief Get context of execution plan.
|
||||
///
|
||||
/// \return Context of execution plan.
|
||||
virtual std::shared_ptr<Context> GetContext() = 0;
|
||||
|
||||
/// \brief Set context to run.
|
||||
///
|
||||
/// \param[in] context, context
|
||||
///
|
||||
/// \return void.
|
||||
virtual void SetContext(std::shared_ptr<Context> context) = 0;
|
||||
|
||||
/// \brief Get callback before kernel execution.
|
||||
///
|
||||
/// \return KernelCallBack pointer.
|
||||
virtual const KernelCallBack &GetKernelBeforeCallBack() = 0;
|
||||
|
||||
/// \brief Set callback before kernel execution.
|
||||
///
|
||||
/// \param[in] callback, callback function pointer
|
||||
///
|
||||
/// \return void.
|
||||
virtual void SetKernelBeforeCallBack(const KernelCallBack &callback) = 0;
|
||||
|
||||
/// \brief Get callback after kernel execution.
|
||||
///
|
||||
/// \return KernelCallBack pointer.
|
||||
virtual const KernelCallBack &GetKernelAfterCallBack() = 0;
|
||||
|
||||
/// \brief Set callback after kernel execution.
|
||||
///
|
||||
/// \param[in] callback, callback function pointer
|
||||
///
|
||||
/// \return void.
|
||||
virtual void SetKernelAfterCallBack(const KernelCallBack &callback) = 0;
|
||||
|
||||
/// \brief Convert Execution Plan to Kernel List
|
||||
///
|
||||
/// \return Kernel List
|
||||
virtual std::vector<Kernel *> ToKernelList() = 0;
|
||||
};
|
||||
} // namespace mindspore::infer::abstract
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "infer/execution_flow.h"
|
||||
#include "infer/tensor.h"
|
||||
|
||||
namespace mindspore::infer::abstract {
|
||||
class Executor : public std::enable_shared_from_this<Executor> {
|
||||
|
@ -38,7 +38,7 @@ class Executor : public std::enable_shared_from_this<Executor> {
|
|||
/// \param[in] execution_flow Abstract Execution Plan for execute.
|
||||
///
|
||||
/// \return Status.
|
||||
virtual Status Prepare(std::shared_ptr<ExecutionFlow> execution_flow) = 0;
|
||||
virtual Status Prepare() = 0;
|
||||
|
||||
/// \brief Execute According to ExecutionFlow.
|
||||
///
|
||||
|
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_INFER_GRAPH__RUNTIME_H_
|
||||
#define MINDSPORE_LITE_INFER_GRAPH__RUNTIME_H_
|
||||
#ifndef MINDSPORE_LITE_INFER_GRAPH_RUNTIME_H_
|
||||
#define MINDSPORE_LITE_INFER_GRAPH_RUNTIME_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
@ -50,6 +50,14 @@ class GraphRuntime : public std::enable_shared_from_this<GraphRuntime> {
|
|||
virtual Status Execute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
|
||||
KernelCallBack before = nullptr, KernelCallBack after = nullptr) = 0;
|
||||
|
||||
/// \brief Resize According to New Inputs and dims.
|
||||
///
|
||||
/// \param[in] inputs, inputs tensors to resize
|
||||
/// \param[in] dims, targe dim shape to resize
|
||||
///
|
||||
/// \return Status.
|
||||
virtual Status Resize(const std::vector<Tensor *> *inputs, const std::vector<std::vector<int64_t>> &dims) = 0;
|
||||
|
||||
/// \brief Get list of inputs for the model.
|
||||
///
|
||||
/// \return vector of Tensor.
|
||||
|
@ -62,4 +70,4 @@ class GraphRuntime : public std::enable_shared_from_this<GraphRuntime> {
|
|||
};
|
||||
} // namespace mindspore::infer::abstract
|
||||
|
||||
#endif // MINDSPORE_LITE_INFER_GRAPH__RUNTIME_H_
|
||||
#endif // MINDSPORE_LITE_INFER_GRAPH_RUNTIME_H_
|
||||
|
|
|
@ -13,19 +13,13 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_INFER_CALLBACK_H_
|
||||
#define MINDSPORE_LITE_INFER_CALLBACK_H_
|
||||
#ifndef MINDSPORE_LITE_INFER_KERNEL_CALLBACK_H_
|
||||
#define MINDSPORE_LITE_INFER_KERNEL_CALLBACK_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "executor/kernel_exec.h"
|
||||
|
||||
namespace mindspore::infer::abstract {
|
||||
using KernelCallBack = mindspore::lite::KernelCallBack;
|
||||
// class CallBack : public std::enable_shared_from_this<CallBack> {
|
||||
// public:
|
||||
// virtual ~CallBack() = default;
|
||||
// };
|
||||
} // namespace mindspore::infer::abstract
|
||||
|
||||
#endif // MINDSPORE_LITE_INFER_CALLBACK_H_
|
||||
#endif // MINDSPORE_LITE_INFER_KERNEL_CALLBACK_H_
|
||||
|
|
Loading…
Reference in New Issue