forked from OSSInnovation/mindspore
!14527 connect the process of actor runtime
From: @limingqi107 Reviewed-by: @cristoval, @wilfchen Signed-off-by: @wilfchen
Commit: 08eb27287e
@@ -50,6 +50,43 @@
namespace mindspore {
namespace pipeline {
namespace {
void TaskEmitActionForMindRT(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  // Get the mindRT backend.
  auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
  auto mindrt_bc_ptr = std::dynamic_pointer_cast<compile::MindRTBackend>(bc_ptr);
  MS_EXCEPTION_IF_NULL(mindrt_bc_ptr);

  auto cut_list = compile::GetMsNonlinearOps();
  auto mindrt_compile = std::make_shared<compile::GraphCompiler>(mindrt_bc_ptr, cut_list);
  // The output of the graph compiler is the graph id.
  res->results()[kOutput] = mindrt_compile->CompileGraphs(res->func_graph());
}

void ExecuteActionForMindRT(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  if (!res->results()[kOutput].is<GraphId>()) {
    MS_LOG(EXCEPTION) << "Execute args error";
  }
  auto graph_id = res->results()[kOutput].cast<GraphId>();

  // Get the mindRT backend.
  std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
  auto mindrt_bc_ptr = (std::dynamic_pointer_cast<compile::MindRTBackend>(bc_ptr)).get();
  MS_EXCEPTION_IF_NULL(mindrt_bc_ptr);

  // Construct the graph run function ptr.
  compile::VmEvalFuncPtr run =
    std::make_shared<compile::VmEvalFunc>([mindrt_bc_ptr, graph_id](const VectorRef &args) -> BaseRef {
      MS_LOG(INFO) << "Execute args size " << args.size();
      auto outs = mindrt_bc_ptr->RunGraph(graph_id, args);
      MS_LOG(DEBUG) << "out size " << outs.size();
      return outs[0];
    });
  res->results()[kOutput] = run;
}
}  // namespace
using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult;
using mindspore::abstract::AnalysisContextPtr;
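For context: the VmEvalFuncPtr stored in res->results()[kOutput] above is what the pipeline later fetches and invokes with the real graph inputs. A minimal sketch of that call site (hypothetical, not part of this commit):

// Hypothetical caller of the run function built in ExecuteActionForMindRT.
const auto &run = res->results()[kOutput].cast<compile::VmEvalFuncPtr>();
VectorRef args;                 // graph inputs, filled by the caller
BaseRef output = (*run)(args);  // ends up in MindRTBackend::RunGraph(graph_id, args)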
@@ -488,6 +525,13 @@ bool TaskEmitAction(const ResourcePtr &res) {
    }
  }

  // The graph compiling of mindRT.
  if ((backend == kMsConvert) && compile::IsMindRTUsed()) {
    TaskEmitActionForMindRT(res);
    return true;
  }

  // The graph compiling of control sink.
  if (IsCtrlSink() && backend == kMsConvert) {
    res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
    return true;
@@ -510,6 +554,14 @@ bool ExecuteAction(const ResourcePtr &res) {
    MS_LOG(EXCEPTION) << "Execute args error";
  }
  std::string backend = MsContext::GetInstance()->backend_policy();

  // The graph running of mindRT.
  if ((backend == kMsConvert) && compile::IsMindRTUsed()) {
    ExecuteActionForMindRT(res);
    return true;
  }

  // The graph running of control sink.
  if (IsCtrlSink() && backend == kMsConvert) {
    if (!res->results()[kOutput].is<GraphId>()) {
      MS_LOG(EXCEPTION) << "Execute args error";
@@ -25,7 +25,7 @@ void KernelActor::RunOpData(OpDataPtr<DeviceTensor> input_data, OpContext<Device
  MS_EXCEPTION_IF_NULL(context);
  auto sequential_num = context->sequential_num_;
  input_op_datas_[sequential_num].emplace_back(input_data);
-  // When all the input data are collected, then allocate memory and callback launch.
+  // When all the inputs are collected, then allocate memory and callback launch.
  if (CheckLaunchCondition(context)) {
    FetchInputDeviceTensor(context);
    FetchOutputDeviceTensor();

@@ -38,7 +38,7 @@ void KernelActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *cont
  MS_EXCEPTION_IF_NULL(context);
  auto sequential_num = context->sequential_num_;
  input_op_controls_[sequential_num].emplace_back(input_control);
-  // When all the input data are collected, then allocate memory and callback launch.
+  // When all the inputs are collected, then allocate memory and callback launch.
  if (CheckLaunchCondition(context)) {
    FetchInputDeviceTensor(context);
    FetchOutputDeviceTensor();
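The changed comments describe the actor's firing rule. A plausible shape of CheckLaunchCondition, assuming expected-input counters (input_datas_num_, input_controls_num_) that this diff does not show — a hedged sketch, not the committed implementation:

// Hypothetical: fire only when every expected data arrow and control arrow
// has delivered its payload for this sequential step.
bool KernelActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) {
  MS_EXCEPTION_IF_NULL(context);
  auto sequential_num = context->sequential_num_;
  return (input_op_datas_[sequential_num].size() == input_datas_num_) &&
         (input_op_controls_[sequential_num].size() == input_controls_num_);
}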
@@ -26,7 +26,8 @@ namespace mindspore {
namespace runtime {
using mindspore::tensor::TensorPtr;

-// Host tensor queue is used to store host tensors, and its data will be fetched by the host queue data source actor.
+// Host tensor queue is used to store host tensors (such as non-weight parameters of the graph), and its data will
+// be fetched by the host queue data source actor.
class HostTensorQueue {
 public:
  HostTensorQueue() = default;
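The declaration is cut off here. For intuition, a minimal sketch of the producer/consumer queue the comment describes (a hypothetical shape assuming batch-wise Push/Pop, not the committed class body):

#include <queue>
#include <vector>

class HostTensorQueueSketch {
 public:
  // The front end pushes one batch of host tensors per step.
  void Push(const std::vector<TensorPtr> &tensors) { buffer_.push(tensors); }
  // The host queue data source actor pops a batch to copy to device.
  std::vector<TensorPtr> Pop() {
    auto tensors = buffer_.front();
    buffer_.pop();
    return tensors;
  }
  bool IsEmpty() const { return buffer_.empty(); }

 private:
  std::queue<std::vector<TensorPtr>> buffer_;
};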
@@ -1,3 +1,5 @@
+include_directories(${CMAKE_SOURCE_DIR}/mindspore/core/mindrt/include)
+
file(GLOB_RECURSE _VM_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_VM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_VM)
add_library(_mindspore_vm_obj OBJECT ${_VM_SRC_LIST})
@@ -26,6 +26,9 @@
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
+#include "runtime/hardware/device_context_manager.h"
+#include "runtime/framework/graph_compiler.h"
+#include "runtime/framework/graph_scheduler.h"
#ifdef ENABLE_GE
#include "utils/callbacks_ge.h"
#endif

@@ -221,8 +224,58 @@ void MsBackend::ClearSessionGraphs() {
    target_sess_->ClearGraph();
  }
}

#ifdef ENABLE_DEBUGGER
void MsBackend::SetDebugger() { target_sess_->SetDebugger(); }
#endif

MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id)
    : Backend(backend_name), device_name_(device_name), device_id_(device_id) {}

GraphId MindRTBackend::CompileGraph(const AnfNodePtrList &nodes) {
  // Get and set the device context.
  const auto &cur_device_name = GetCNodeTarget(nodes[0]);
  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_});
  runtime::GraphCompiler::GetInstance().set_device_context(device_context);

  // Transform nodes to inputs and outputs.
  FuncGraphPtr fg;
  AnfNodePtrList inputs;
  AnfNodePtrList outputs;
  std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(nodes);

  // Compile graph.
  return runtime::GraphCompiler::GetInstance().CompileGraph(inputs, outputs);
}

VectorRef MindRTBackend::RunGraph(GraphId graph_id, const VectorRef &args) {
  const auto &context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
    MS_LOG(INFO) << "PrecompileOnly, stop run graph";
    return VectorRef();
  }

  // Transform args to input tensors.
  std::vector<tensor::TensorPtr> inputs;
  for (const auto &arg : args) {
    PushInputTensor(arg, &inputs);
  }

  // Fetch the kernel graph.
  const auto &kernel_graph = runtime::GraphCompiler::GetInstance().Fetch(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);

  // Fetch the actor DAG.
  const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(kernel_graph);
  MS_EXCEPTION_IF_NULL(actor_set);

  // Run the actor DAG. Filling outputs waits on a GraphScheduler interface, so they stay empty for now.
  VectorRef outputs;
  runtime::GraphScheduler::GetInstance().Run(actor_set);

  return outputs;
}
}  // namespace compile
}  // namespace mindspore
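Taken together, the two new entry points compose into a compile-then-run flow. A minimal sketch, assuming a segment's node list "nodes" is already in hand (the driver code is hypothetical, not part of this commit):

// Hypothetical driver for the new backend path.
compile::MindRTBackend backend("ms", "GPU", 0);        // backend name, device name, device id
GraphId graph_id = backend.CompileGraph(nodes);        // nodes: AnfNodePtrList of one segment
VectorRef args;                                        // graph inputs, filled by the caller
VectorRef outputs = backend.RunGraph(graph_id, args);  // empty until GraphScheduler exposes outputs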
@@ -21,6 +21,7 @@
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "utils/contract.h"
#include "ir/anf.h"

@@ -31,6 +32,8 @@
namespace mindspore {
namespace compile {
using OpRunInfo = session::OpRunInfo;

enum SwitchCondStatus {
  kCondOk = 0,
  kCondAlreadyRun,

@@ -85,6 +88,27 @@ class MsBackend : public Backend {
  std::string other_device_;
  std::unordered_map<GraphId, LinConvertResult> graph_id_map_;
};

class MindRTBackend : public Backend {
 public:
  MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id);
  ~MindRTBackend() override = default;

  // Compile a kernel graph from an anf node list in graph mode.
  GraphId CompileGraph(const AnfNodePtrList &nodes);
  // Compile a single-op kernel graph in pyNative mode.
  GraphId CompileGraph(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                       const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask);

  // Run a graph in graph mode.
  VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
  // Run a graph in pyNative mode.
  VectorRef RunGraph(const GraphInfo &graph_info, const VectorRef &args);

 private:
  std::string device_name_;
  uint32_t device_id_;
};
}  // namespace compile
}  // namespace mindspore
#endif
@@ -521,6 +521,73 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
  return rt;
}

GraphCompiler::GraphCompiler(const std::shared_ptr<MindRTBackend> &backend, const std::vector<PrimitivePtr> &cut_list)
    : backend_(backend) {
  // MS_EXCEPTION_IF_NULL already guards against a null backend, so no separate null check is needed.
  MS_EXCEPTION_IF_NULL(backend_);
  graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend->name());
}

uint32_t GraphCompiler::CompileGraphs(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphPtr root_graph = WrapPrimitives(func_graph);
  MS_EXCEPTION_IF_NULL(root_graph);

  // Compile the root graph.
  auto root_graph_id = CompileGraph(root_graph);

  // Compile the sub graphs.
  FuncGraphSet sub_graphs = root_graph->manager()->func_graphs();
  for (auto sub_graph : sub_graphs) {
    if (sub_graph != func_graph && sub_graph != nullptr && !(sub_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
      (void)CompileGraph(sub_graph);
    }
  }

  return root_graph_id;
}

uint32_t GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(graph_partition_);
  MS_EXCEPTION_IF_NULL(backend_);

  // Split the graph into segments.
  const auto &segments = graph_partition_->Partition(func_graph);
  MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", split segments size: " << segments.size();

  // Iterate over the segments and compile each one.
  std::vector<uint32_t> graph_ids;
  for (const auto &segment : segments) {
    MS_EXCEPTION_IF_NULL(segment);
    // Compile a normal segment, which doesn't contain a cut node.
    if (!segment->is_cut_) {
      if (segment->nodes_.size() == 0) {
        MS_LOG(EXCEPTION) << "The segment's node list is empty.";
      }
      MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->fullname_with_scope();

      // Compile the anf node list to a kernelGraph and record the kernelGraph id.
      auto graph_id = backend_->CompileGraph(segment->nodes_);
      graph_ids.emplace_back(graph_id);
    } else {
      // A cut segment holds a single cut node.
      auto cut_node = segment->nodes_[0];
      MS_EXCEPTION_IF_NULL(cut_node);
      MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->fullname_with_scope();
    }
  }

  return graph_ids[0];
}

// Judge whether to use mindRT. GPU and CPU use mindRT currently; other hardware will support it in the future.
// Return false in the transitional stage.
bool IsMindRTUsed() { return false; }

BackendPtr CreateBackend() {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
@@ -533,7 +600,13 @@ BackendPtr CreateBackend() {
  if (name == kMsConvert) {
    std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
-    auto backend = std::make_shared<MsBackend>(name, target, device_id);
+    BackendPtr backend = nullptr;
+    // Create a MindRTBackend or MsBackend according to whether mindRT is used.
+    if (IsMindRTUsed()) {
+      backend = std::make_shared<MindRTBackend>(name, target, device_id);
+    } else {
+      backend = std::make_shared<MsBackend>(name, target, device_id);
+    }
    std::string device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    if (device_target == kAscendDevice) {
      if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
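End to end, the new compile path mirrors TaskEmitActionForMindRT above. A minimal sketch, assuming func_graph is a FuncGraphPtr already in hand (driver code is hypothetical, not part of this commit):

// Hypothetical: build the mindRT backend and compile a func graph to kernel graphs.
auto backend = std::dynamic_pointer_cast<compile::MindRTBackend>(compile::CreateBackend());
// Note: CreateBackend only yields a MindRTBackend once IsMindRTUsed() returns true,
// which is still false in this transitional commit, so the cast would yield nullptr today.
auto compiler = std::make_shared<compile::GraphCompiler>(backend, compile::GetMsNonlinearOps());
uint32_t root_graph_id = compiler->CompileGraphs(func_graph);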
@@ -131,6 +131,30 @@ class CompileGraphs {
  BackendPtr backend_;
};

// The graph compiler used with mindRT: it transforms a funcGraph into kernelGraphs and returns the
// graph id of the root kernelGraph.
class GraphCompiler {
 public:
  GraphCompiler(const std::shared_ptr<MindRTBackend> &backend,
                const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
  ~GraphCompiler() = default;

  // The parameter root_graph may contain multiple sub graphs; traverses all of them, calling
  // CompileGraph on each, and returns the kernelGraph id of the root graph.
  uint32_t CompileGraphs(const FuncGraphPtr &root_graph);

 private:
  // The parameter func_graph can be either a root graph or a sub graph; returns the corresponding
  // kernelGraph id of the graph.
  uint32_t CompileGraph(const FuncGraphPtr &func_graph);

  std::shared_ptr<MindRTBackend> backend_;
  GraphPartitionPtr graph_partition_;
};

// Judge whether to use mindRT. GPU and CPU use mindRT currently; other hardware will support it in the future.
bool IsMindRTUsed();

BackendPtr CreateBackend();

}  // namespace compile