!41056 [MS]Optimize backend

Merge pull request !41056 from 张学同/assert
This commit is contained in:
i-robot 2022-09-01 01:06:11 +00:00 committed by Gitee
commit bd229fcf11
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 934 additions and 830 deletions

View File

@ -33,6 +33,8 @@
"mindspore/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc" "knownConditionTrueFalse"
"mindspore/mindspore/ccsrc/backend/graph_compiler/backend.cc" "knownConditionTrueFalse"
"mindspore/mindspore/ccsrc/backend/graph_compiler/backend.cc" "variableScope"
"mindspore/mindspore/ccsrc/backend/graph_compiler/backend_base.cc" "knownConditionTrueFalse"
"mindspore/mindspore/ccsrc/backend/graph_compiler/backend_base.cc" "variableScope"
"mindspore/mindspore/core/ops/max_pool.cc" "zerodivcond"
# MindData

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -52,17 +52,6 @@
namespace mindspore {
namespace compile {
bool Backend::GetCond(const BaseRef &c, bool *value) {
mindspore::ScopedLongRunning long_running;
return BaseRefToBool(c, value);
}
bool Backend::GetIndex(const BaseRef &c, int64_t *value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }
Backend::Backend(const std::string &name) : name_(name), is_multi_graph_sink_(false) {
MS_LOG(DEBUG) << "Select backend:" << name;
convert_fn_ = MsVmConvert;
}
LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std::string &target) {
MS_LOG(DEBUG) << "MsConvert";
MS_EXCEPTION_IF_NULL(segment);
@ -154,142 +143,6 @@ std::vector<tensor::TensorPtr> GetTensorWithoutValueMask(const session::BackendO
return tensors_without_value_node;
}
void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs) {
MS_EXCEPTION_IF_NULL(inputs);
if (utils::isa<tensor::TensorPtr>(arg)) {
auto value = utils::cast<tensor::TensorPtr>(arg);
inputs->push_back(value);
} else if (utils::isa<ValuePtr>(arg)) {
auto value = utils::cast<ValuePtr>(arg);
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
auto tuple_value = value_tuple->value();
(void)std::transform(tuple_value.begin(), tuple_value.end(), std::back_inserter(*inputs),
[](const ValuePtr &v) { return v->cast<tensor::TensorPtr>(); });
} else if (value->isa<Scalar>()) {
tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>());
inputs->push_back(scalar_tensor);
} else if (value->isa<Monad>()) {
// If value is a monad, replace it with an unused tensor.
inputs->push_back(std::make_shared<tensor::Tensor>(int64_t(0), kBool));
} else {
inputs->push_back(value->cast<tensor::TensorPtr>());
}
} else if (utils::isa<PyObjectRef>(arg)) {
auto value = utils::cast<PyObjectRef>(arg).object_;
inputs->push_back(py::cast<tensor::TensorPtr>(value));
} else if (utils::isa<VectorRefPtr>(arg)) {
const auto &args_new = utils::cast<VectorRef>(arg);
for (const auto &v : args_new) {
PushInputTensor(v, inputs);
}
} else {
MS_LOG(WARNING) << "Invalid input type.";
}
}
// Move these function to anonymous namespace
void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value) {
MS_EXCEPTION_IF_NULL(flatted_value);
for (auto value_element : value) {
MS_EXCEPTION_IF_NULL(value_element);
if (utils::isa<tensor::TensorPtr>(value_element)) {
(void)flatted_value->emplace_back(value_element);
} else if (utils::isa<ValueTuplePtr>(value_element)) {
auto value_tuple_element = value_element->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple_element);
FlatValueTupleValue(value_tuple_element->value(), flatted_value);
} else {
MS_LOG(EXCEPTION) << "The value input to FlatValueTupleValue should only contains Tensor and ValueTuple.";
}
}
}
void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) {
MS_EXCEPTION_IF_NULL(flatted_value);
if (utils::isa<ValueSequencePtr>(arg)) {
auto value_sequence = utils::cast<ValueSequencePtr>(arg);
MS_EXCEPTION_IF_NULL(value_sequence);
auto sequence_value = value_sequence->value();
for (auto &value : sequence_value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
(void)flatted_value->emplace_back(value);
} else {
FlattenValue(value, flatted_value);
}
}
} else if (utils::isa<ValueDictionaryPtr>(arg)) {
auto value_dict = utils::cast<ValueDictionaryPtr>(arg);
MS_EXCEPTION_IF_NULL(value_dict);
auto dict_value = value_dict->value();
for (auto &iter : dict_value) {
auto value = iter.second;
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
(void)flatted_value->emplace_back(value);
} else {
FlattenValue(value, flatted_value);
}
}
} else if (utils::isa<tensor::COOTensorPtr>(arg)) {
auto coo_tensor = utils::cast<tensor::COOTensorPtr>(arg);
MS_EXCEPTION_IF_NULL(coo_tensor);
for (size_t i = 0; i < coo_tensor->GetTensorLength(); ++i) {
(void)flatted_value->emplace_back(coo_tensor->GetTensorAt(i));
}
} else if (utils::isa<tensor::CSRTensorPtr>(arg)) {
auto csr_tensor = utils::cast<tensor::CSRTensorPtr>(arg);
MS_EXCEPTION_IF_NULL(csr_tensor);
for (size_t i = 0; i < csr_tensor->GetTensorLength(); ++i) {
(void)flatted_value->emplace_back(csr_tensor->GetTensorAt(i));
}
} else {
MS_LOG(EXCEPTION) << "The value input to flatten should only contains be sequence or dictionary, but it is "
<< arg.ToString();
}
}
// Insert the front_node related tensor in the input_tensor.
void PushTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
std::vector<tensor::TensorPtr> *input_tensors) {
MS_EXCEPTION_IF_NULL(input_tensors);
const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
if (iter == parameters.end()) {
(void)((*input_tensors).emplace_back(nullptr));
return;
}
auto position = iter - parameters.begin();
PushInputTensor(args[position], input_tensors);
}
void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
size_t index, std::vector<tensor::TensorPtr> *input_tensors) {
MS_EXCEPTION_IF_NULL(input_tensors);
const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
const size_t position = iter - parameters.begin();
// If the parameter is not found in the parameters of the root graph, it means that it is the input of the subgraph,
// and there is no need to input a tensor.
if (position >= args.size()) {
MS_LOG(DEBUG) << "Position out of args range, position value is " << position << " and args size is " << args.size()
<< ".";
(void)input_tensors->emplace_back(nullptr);
return;
}
ValuePtrList flatted_value_tuple_value;
FlattenValue(args[position], &flatted_value_tuple_value);
if (index >= flatted_value_tuple_value.size()) {
MS_LOG(EXCEPTION) << "Index out of flatted_value_tuple_value range, index value is " << index
<< " and flatted_value_tuple_value size is " << flatted_value_tuple_value.size() << ".";
}
auto input = flatted_value_tuple_value[index];
MS_EXCEPTION_IF_NULL(input);
auto tensor_input = input->cast<tensor::TensorPtr>();
input_tensors->push_back(tensor_input);
}
void UpdateOutputAbstract(const KernelGraphPtr &kernel_graph, const session::BackendOpRunInfoPtr &op_run_info) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(op_run_info);
@ -302,35 +155,6 @@ void UpdateOutputAbstract(const KernelGraphPtr &kernel_graph, const session::Bac
}
}
TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index) {
MS_EXCEPTION_IF_NULL(output_node);
// Create host tensor, the output tensor should use the infer type, it will be handed correctly by tensor data sync
// when infer type is not equal to device type.
auto type_id = common::AnfAlgo::GetOutputInferDataType(output_node, output_index);
const auto &shape = common::AnfAlgo::GetOutputInferShape(output_node, output_index);
auto tensor = std::make_shared<tensor::Tensor>(type_id, shape);
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(output_node, output_index));
// Put device tensor into host tensor.
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
MS_EXCEPTION_IF_NULL(device_tensor);
device_tensor->SetNodeIndex(output_node, output_index);
tensor->set_device_address(device_tensor);
tensor->set_sync_status(kNeedSyncDeviceToHost);
// MindRT is disabled in the multi graphs scenario
// Delete tensor->data_sync() when MindRT is enabled in all scenes.
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
// If execution mode is Graph Mode in MsContext, the tensor will be the input of graph which will execute in Graph
// Mode, if the graph contain no CNode after optimization, the tensor need sync to host.
tensor->data_sync(false);
}
return tensor;
}
device::DeviceAddressPtr CloneEmptyDeviceAddress(const device::DeviceAddressPtr &old_device_address,
const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(old_device_address);
@ -413,47 +237,6 @@ bool EnablePyNativeSyncRunning() {
MS_EXCEPTION_IF_NULL(ms_context);
return ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
}
std::vector<std::vector<tensor::TensorPtr>> GetRunGraphInputs(const GraphCompilerInfo &graph_compiler_info,
const VectorRef &args) {
const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
std::vector<std::vector<tensor::TensorPtr>> input_tensor_lists;
for (const auto &kernel_graph : graph_compiler_info.graphs_) {
std::vector<tensor::TensorPtr> input_tensors;
MS_EXCEPTION_IF_NULL(kernel_graph);
for (const auto &input_node : kernel_graph->input_nodes()) {
auto element_pair = kernel_graph->GetElementInTupleBackendFrontIndexMap(input_node);
if (element_pair.first) {
PushTupleTensor(args, origin_parameters, element_pair.first, element_pair.second, &input_tensors);
} else {
const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
PushTensor(args, origin_parameters, front_node, &input_tensors);
}
}
(void)input_tensor_lists.emplace_back(input_tensors);
}
// Input tensors of the control node.
std::vector<tensor::TensorPtr> input_tensors;
MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
// Get inputs of control node which come from the host actor.
const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
for (const auto &parameter_with_index : control_node_parameters) {
const auto &parameter = parameter_with_index.first;
MS_EXCEPTION_IF_NULL(parameter);
const auto &abs = parameter->abstract();
MS_EXCEPTION_IF_NULL(abs);
if (abs->isa<abstract::AbstractTuple>()) {
MS_LOG(DEBUG) << "Fetch input tensor for tuple parameter:" << parameter->DebugString() << " in control flow.";
PushTupleTensor(args, origin_parameters, parameter, parameter_with_index.second, &input_tensors);
} else {
PushTensor(args, origin_parameters, parameter, &input_tensors);
}
}
(void)input_tensor_lists.emplace_back(input_tensors);
return input_tensor_lists;
}
} // namespace
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
@ -540,194 +323,6 @@ void MsBackend::SetDebugger() {
}
#endif
MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id)
: Backend(backend_name), device_name_(device_name) {
root_graph_ = nullptr;
auto ms_context = MsContext::GetInstance();
const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
auto &cut_list = pynative_mode ? GetControlOps() : GetMsNonlinearOps();
graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend_name);
graph_compiler_ = std::make_shared<GraphCompiler>();
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
device_context->Initialize();
device_id_ = device_context->device_context_key().device_id_;
#ifdef ENABLE_DEBUGGER
SetDebuggerInit();
#endif
runtime::GraphScheduler::GetInstance().Initialize();
}
void MindRTBackend::ProcessNotSupportCnode(const FuncGraphPtr &func_graph,
const mindspore::device::DeviceType &old_target,
const mindspore::device::DeviceType &new_target) const {
const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
for (const auto &node : all_nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (!common::AnfAlgo::HasNodeAttr(kAttrNotSupportOpForDevice, cnode)) {
continue;
}
auto not_support_device = common::AnfAlgo::GetNodeAttr<std::string>(node, kAttrNotSupportOpForDevice);
if (device::GetDeviceTypeByName(not_support_device) != old_target) {
continue;
}
common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(device::GetDeviceNameByType(new_target)), node);
}
}
const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(graph_compiler_);
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(INFO) << "Status record: start compile function graph: " << func_graph->ToString();
PROF_START(compile_func_graph);
auto root_graph = WrapPrimitives(func_graph);
MS_EXCEPTION_IF_NULL(root_graph);
root_graph_ = root_graph;
// Register a summary callback function, which is called in the final stages of summary.
graph_compiler_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
real_execution_mode_ = ms_execution_mode_;
func_graph->set_flag(kFlagPyNativeRunInGraph, real_execution_mode_ == kPynativeMode);
// Compile root graph.
graph_id_to_device_context_.clear();
func_graph_to_kernel_graph_ids_.clear();
control_nodes_.clear();
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
MS_EXCEPTION_IF_NULL(device_context);
bool all_support = device_context->PartitionGraph(func_graph);
if (all_support) {
auto run_mode = device_context->GetRunMode(func_graph);
if (run_mode == device::RunMode::kGraphMode && pynative::GraphAdapter::PyNativeEnableTaskSink(func_graph)) {
auto graph_id = graph_compiler_->CompileWholeGraphForGraphRunMode(func_graph, device_context);
graph_id_to_device_context_[graph_id] = device_context;
} else {
CompileSubGraph(func_graph, device::RunMode::kKernelMode);
}
} else {
ProcessNotSupportCnode(func_graph, device_context->GetDeviceType(), mindspore::device::DeviceType::kCPU);
CompileSubGraph(func_graph);
}
// Construct the graph compiler info.
auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph);
MS_EXCEPTION_IF_NULL(graph_compiler_info);
if (real_execution_mode_ == kGraphMode &&
((!graph_compiler_info->graphs_.empty()) || graph_compiler_info->control_nodes_.size() > 1)) {
// Transform graph to actor DAG, and schedule the actor DAG.
ParseControlNodes(*graph_compiler_info);
const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
runtime::GraphScheduler::GetInstance().Schedule(actor_set);
}
const ActorInfo &actor_info = graph_compiler_info->name_;
(void)actor_to_graph_compiler_info_.emplace(graph_compiler_info->name_, std::move(graph_compiler_info));
PROF_END(compile_func_graph);
if (ms_execution_mode_ != real_execution_mode_) {
context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, ms_execution_mode_);
}
MS_LOG(INFO) << "Status record: end compile function graph: " << func_graph->ToString()
<< ", produce actor: " << actor_info;
return actor_info;
}
void MindRTBackend::CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
auto root_graph = WrapPrimitives(func_graph);
MS_EXCEPTION_IF_NULL(root_graph);
CompileGraph(root_graph, run_mode);
MS_EXCEPTION_IF_NULL(root_graph->manager());
FuncGraphSet sub_graphs = root_graph->manager()->func_graphs();
for (const auto &sub_graph : sub_graphs) {
if (sub_graph != func_graph && sub_graph != nullptr) {
CompileGraph(sub_graph, run_mode);
}
}
}
void MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(graph_partition_);
MS_EXCEPTION_IF_NULL(graph_compiler_);
bool contain_multi_target = false;
// Split graph to segments.
const auto &segments = graph_partition_->Partition(func_graph, &contain_multi_target);
MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", Split segments size:" << segments.size();
// Foreach the segments to compile graph.
for (const auto &segment : segments) {
CompileGraph(segment, run_mode);
}
}
void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, device::RunMode run_mode) {
MS_EXCEPTION_IF_NULL(segment);
// Compile the normal nodes, which doesn't contain the cut node.
if (segment->nodes_.size() == 0) {
MS_LOG(EXCEPTION) << "The segments size is 0.";
}
if (!segment->is_cut_) {
MS_EXCEPTION_IF_NULL(segment->nodes_[0]);
MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->DebugString();
// Get the device context.
const auto &cur_device_name = GetCNodeTarget(segment->nodes_[0]);
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_});
MS_EXCEPTION_IF_NULL(device_context);
device_context->Initialize();
// Transform nodes to inputs and outputs.
FuncGraphPtr fg;
AnfNodePtrList inputs;
AnfNodePtrList outputs;
std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
// Compile graph.
auto graph_id =
graph_compiler_->CompileGraph(segment, outputs, device_context, run_mode, real_execution_mode_ == kPynativeMode);
graph_id_to_device_context_[graph_id] = device_context;
const auto &func_graph = segment->nodes_[0]->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
if (func_graph_to_kernel_graph_ids_.find(func_graph) == func_graph_to_kernel_graph_ids_.end()) {
(void)func_graph_to_kernel_graph_ids_[func_graph].emplace_back(std::vector<GraphId>{graph_id});
} else {
(void)func_graph_to_kernel_graph_ids_[func_graph].back().emplace_back(graph_id);
}
} else {
// Compile the cut node.
auto cut_node = segment->nodes_[0];
MS_EXCEPTION_IF_NULL(cut_node);
MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->DebugString();
control_nodes_.push_back(cut_node);
if (common::AnfAlgo::IsCallNode(cut_node) || common::AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitch) ||
common::AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitchLayer)) {
const auto &func_graph = cut_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
(void)func_graph_to_kernel_graph_ids_[func_graph].emplace_back(std::vector<GraphId>());
}
}
}
namespace {
void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, const CNodePtr &front_cnode,
const CNodePtr &backend_cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output_map,
@ -878,101 +473,36 @@ void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler, co
}
}
void TensorValueToVector(const ValuePtr &value, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(value);
MS_EXCEPTION_IF_NULL(outputs);
if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
for (size_t i = 0; i < value_tuple->size(); ++i) {
ValuePtr element = value_tuple->value()[i];
MS_EXCEPTION_IF_NULL(element);
if (element->isa<tensor::Tensor>()) {
auto tensor = element->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
outputs->emplace_back(tensor);
} else if (element->isa<ValueTuple>()) {
VectorRef tuple;
TensorValueToVector(element, &tuple);
outputs->emplace_back(tuple);
}
}
} else if (value->isa<tensor::Tensor>()) {
auto tensor = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
outputs->emplace_back(tensor);
}
TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index) {
MS_EXCEPTION_IF_NULL(output_node);
// Create host tensor, the output tensor should use the infer type, it will be handed correctly by tensor data sync
// when infer type is not equal to device type.
auto type_id = common::AnfAlgo::GetOutputInferDataType(output_node, output_index);
const auto &shape = common::AnfAlgo::GetOutputInferShape(output_node, output_index);
auto tensor = std::make_shared<tensor::Tensor>(type_id, shape);
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(output_node, output_index));
// Put device tensor into host tensor.
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
MS_EXCEPTION_IF_NULL(device_tensor);
device_tensor->SetNodeIndex(output_node, output_index);
tensor->set_device_address(device_tensor);
tensor->set_sync_status(kNeedSyncDeviceToHost);
// MindRT is disabled in the multi graphs scenario
// Delete tensor->data_sync() when MindRT is enabled in all scenes.
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
// If execution mode is Graph Mode in MsContext, the tensor will be the input of graph which will execute in Graph
// Mode, if the graph contain no CNode after optimization, the tensor need sync to host.
tensor->data_sync(false);
}
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &graph_output, const VectorRef &args, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(graph_output);
MS_EXCEPTION_IF_NULL(outputs);
if (graph_output->isa<ValueNode>()) {
MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
VectorRef output_tmp;
ValuePtr value = GetValueNode(graph_output);
TensorValueToVector(value, &output_tmp);
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>()) {
outputs->emplace_back(output_tmp);
} else if (value->isa<tensor::Tensor>()) {
*outputs = output_tmp;
} else {
MS_LOG(INFO) << "Graph output is empty!";
}
return true;
}
if (graph_output->isa<Parameter>()) {
MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
// Find the right parameter as ret_val.
auto func_graph = graph_output->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto params = func_graph->parameters();
if (args.size() != params.size()) {
MS_LOG(EXCEPTION) << "Input size " << args.size() << " not equal to graph input size " << params.size();
}
auto it = std::find(params.begin(), params.end(), graph_output);
if (it == params.end()) {
MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
}
size_t index = it - params.cbegin();
if (index >= args.size()) {
MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size();
}
outputs->emplace_back(args[index]);
return true;
}
return false;
return tensor;
}
} // namespace
void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, const FuncGraphPtr &root_graph) {
bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
distributed::recovery::RecoveryContext::GetInstance()->need_reset());
bool is_embedding_cache_server = false;
#ifdef WITH_BACKEND
is_embedding_cache_server = ps::PSContext::instance()->cache_enable() && ps::PSContext::instance()->is_server();
#endif
if (need_contruct_output) {
// Update device address for output node of graph.
// Summary processing will use the output device address, so must be after the summary processing.
if (!is_embedding_cache_server) {
actor_set->output_actor_->UpdateOutputDeviceAddress();
}
// Fetch outputs.
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
auto &output_tensors = actor_set->output_actor_->outputs();
if (!output_tensors.empty()) {
size_t output_position = 0;
ConstructOutputs(root_graph->output(), output_tensors, &output_position, outputs);
}
}
}
void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
const VectorRef &args, VectorRef *outputs) {
WaitTaskFinish();
@ -1150,176 +680,6 @@ void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const Graph
MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
}
void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(root_graph_);
if (IsGraphOutputValueNodeOrParameter(root_graph_->output(), args, outputs)) {
return;
}
const auto &context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
MS_LOG(INFO) << "PrecompileOnly, stop run graph";
return;
}
// Open abstract_lock for dynamic_shape
AnfUtils::OpenAbstractLock();
MS_LOG(INFO) << "Status record: start run actor: " << actor_info;
// Fetch the graph compiler info.
const auto &graph_iter = actor_to_graph_compiler_info_.find(actor_info);
if (graph_iter == actor_to_graph_compiler_info_.end()) {
MS_LOG(EXCEPTION) << "Can't find the graph compiler info.";
}
MS_EXCEPTION_IF_NULL(graph_iter->second);
const auto &graph_compiler_info = *(graph_iter->second);
// For pynative and graph mix execution.
WaitTaskFinish();
// Run in the pynative mode.
MS_EXCEPTION_IF_NULL(outputs);
// There will be more than one kernel graph in heterogeneous scenario in a ms function of PyNative Mode.
if (real_execution_mode_ == kPynativeMode) {
RunGraphByCondition(actor_info, graph_compiler_info, args, outputs);
return;
}
auto input_tensors = GetRunGraphInputs(graph_compiler_info, args);
// Release python gil.
mindspore::ScopedLongRunning long_running;
// Run actor DAG.
const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
MS_EXCEPTION_IF_NULL(actor_set);
runtime::GraphScheduler::GetInstance().Run(actor_set, input_tensors);
MS_EXCEPTION_IF_NULL(graph_compiler_);
graph_compiler_->Summary(graph_compiler_info.graphs_);
ConstructOutputs(actor_set, outputs, root_graph_);
runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
// Close abstract_lock for dynamic_shape
AnfUtils::CloseAbstractLock();
MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
}
BaseRef MindRTBackend::ConstructOutputByAbstract(const abstract::AbstractBasePtr &abstract,
const std::vector<tensor::TensorPtr> &output_tensors,
size_t *output_position) {
MS_EXCEPTION_IF_NULL(abstract);
MS_EXCEPTION_IF_NULL(output_position);
size_t outputs_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
if (*output_position + outputs_num > output_tensors.size()) {
MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position << " need:" << outputs_num
<< " total:" << output_tensors.size();
}
VectorRef outputs;
if (!abstract->isa<abstract::AbstractTuple>()) {
(*output_position)++;
return output_tensors[(*output_position) - 1];
}
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
for (const auto &sub_abstract : sub_abstracts) {
MS_EXCEPTION_IF_NULL(sub_abstract);
outputs.emplace_back(ConstructOutputByAbstract(sub_abstract, output_tensors, output_position));
}
return outputs;
}
void MindRTBackend::ConstructOutputs(const AnfNodePtr &output_node,
const std::vector<tensor::TensorPtr> &output_tensors, size_t *output_position,
VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(output_node);
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(output_position);
const PrimitiveSet expand_prims{
prim::kPrimMakeTuple,
prim::kPrimMakeCSRTensor,
prim::kPrimMakeCOOTensor,
prim::kPrimMakeRowTensor,
};
// The MakeTuple/MakeSaprse node need expand and recurse.
if (IsOneOfPrimitiveCNode(output_node, expand_prims)) {
auto make_tuple = output_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
VectorRef make_tuple_output;
for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
ConstructOutputs(make_tuple->input(i), output_tensors, output_position, &make_tuple_output);
}
outputs->emplace_back(std::move(make_tuple_output));
return;
}
// The depend node need get the real node.
if (common::AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimDepend)) {
auto depend_node = output_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_node);
ConstructOutputs(depend_node->input(kRealInputIndexInDepend), output_tensors, output_position, outputs);
return;
}
auto outputs_num = common::AnfAlgo::GetOutputTensorNum(output_node);
// The value node uses the value to be output, to avoid the host memory of value free due to value node destruction.
if (output_node->isa<ValueNode>()) {
auto value = output_node->cast<ValueNodePtr>()->value();
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>()) {
outputs->emplace_back(value);
(*output_position) += CountValueNum(value->cast<ValueTuplePtr>());
} else if (outputs_num != 0) {
outputs->emplace_back(value);
(*output_position) += outputs_num;
}
// The empty value node return the empty VectorRef.
return;
}
if (common::AnfAlgo::IsCallNode(output_node)) {
auto abstract = output_node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
outputs->emplace_back(ConstructOutputByAbstract(abstract, output_tensors, output_position));
return;
}
auto &output_abstract = output_node->abstract();
MS_EXCEPTION_IF_NULL(output_abstract);
// Wrap output to VectorRef if the output is tuple.
if (output_abstract->isa<abstract::AbstractTuple>()) {
VectorRef output_tuple;
for (size_t i = 0; i < outputs_num; ++i) {
if (*output_position >= output_tensors.size()) {
MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position;
}
output_tuple.emplace_back(std::move(output_tensors[*output_position]));
++(*output_position);
}
outputs->emplace_back(std::move(output_tuple));
} else {
for (size_t i = 0; i < outputs_num; ++i) {
if (*output_position >= output_tensors.size()) {
MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position;
}
outputs->emplace_back(std::move(output_tensors[*output_position]));
++(*output_position);
}
}
}
#ifdef ENABLE_DEBUGGER
void MindRTBackend::SetDebuggerInit() {
auto debugger_ = Debugger::GetInstance();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
}
#endif
void MindRTBackend::WaitTaskFinish() const { runtime::OpExecutor::GetInstance().Wait(); }
void MindRTBackend::ClearOpExecutorResource() const { runtime::OpExecutor::GetInstance().Reset(); }
@ -1334,54 +694,6 @@ void MindRTBackend::SyncStream() {
}
}
std::shared_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph) {
MS_EXCEPTION_IF_NULL(root_graph);
MS_EXCEPTION_IF_NULL(graph_compiler_);
std::vector<KernelGraphPtr> graphs;
std::vector<DeviceContext *> device_contexts;
std::string name = "kernel_graph";
size_t graph_index = 0;
for (const auto &graph_id_to_context : graph_id_to_device_context_) {
(void)graphs.emplace_back(graph_compiler_->Fetch(graph_id_to_context.first));
(void)device_contexts.emplace_back(graph_id_to_context.second);
if (graph_index == 0) {
(void)name.append("_").append(std::to_string(graph_id_to_context.first));
} else if (graph_index == graph_id_to_device_context_.size() - 1) {
(void)name.append("-").append(std::to_string(graph_id_to_context.first));
}
++graph_index;
}
auto parser = std::make_shared<ControlNodeParser>();
runtime::KernelMapPosition outputs_order;
const auto &root_output =
common::AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
size_t position = 0;
auto outputs = common::AnfAlgo::GetAllOutputWithIndex(root_output);
size_t outputs_num = outputs.size();
for (const auto &output : outputs) {
if (outputs_order.count(output) == 0) {
outputs_order[output] = {position++};
} else {
(void)outputs_order[output].emplace_back(position++);
}
}
std::vector<std::vector<int64_t> *> tensors_mask;
std::vector<std::vector<tensor::TensorPtr> *> input_tensors;
auto strategy = runtime::GraphExecutionStrategy::kPipeline;
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<int>(MS_CTX_MEMORY_OPTIMIZE_LEVEL) != kOptimizeO0) {
strategy = runtime::GraphExecutionStrategy::kPipelineWithExecutionOrder;
}
return std::make_shared<GraphCompilerInfo>(graphs, device_contexts, tensors_mask, input_tensors, control_nodes_,
root_graph->parameters(), parser, outputs_order, outputs_num, name, false,
strategy);
}
void MindRTBackend::EraseSingleOpCache(const GraphInfo &graph_info) {
pynative::OpCompiler::GetInstance().ClearOpCache(graph_info);
}
@ -1617,25 +929,5 @@ void MindRTBackend::UpdateOutput(const std::vector<session::KernelWithIndex> &ou
outputs->emplace_back(output_tensor);
}
}
void MindRTBackend::ParseControlNodes(const GraphCompilerInfo &graph_compile_info) {
FuncGraphToKernelGraphGroup func_graph_to_kernel_graphs;
for (const auto &func_graph_to_kernel_graph_ids : func_graph_to_kernel_graph_ids_) {
const auto &func_graph = func_graph_to_kernel_graph_ids.first;
for (const auto &sub_kernel_graphs_ids : func_graph_to_kernel_graph_ids.second) {
std::vector<KernelGraphPtr> kernel_graphs;
for (const auto &graph_id : sub_kernel_graphs_ids) {
const auto &kernel_graph = graph_compiler_->Fetch(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
(void)kernel_graphs.emplace_back(kernel_graph);
}
(void)func_graph_to_kernel_graphs[func_graph].emplace_back(kernel_graphs);
}
}
graph_compile_info.control_node_parser_->Parse(control_nodes_, graph_compile_info.graphs_,
graph_compile_info.device_contexts_, root_graph_,
func_graph_to_kernel_graphs);
}
} // namespace compile
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,7 @@
#include "utils/hash_map.h"
#include "include/common/utils/contract.h"
#include "ir/anf.h"
#include "backend/graph_compiler/backend_base.h"
#include "backend/graph_compiler/segment_runner.h"
#include "backend/graph_compiler/graph_partition.h"
#include "backend/graph_compiler/vm.h"
@ -39,43 +40,6 @@
namespace mindspore {
namespace compile {
using GraphOutputInfo = session::GraphOutputInfo;
using DeviceContext = device::DeviceContext;
using ActorInfo = runtime::ActorInfo;
using GraphCompiler = runtime::GraphCompiler;
using GraphCompilerInfo = runtime::GraphCompilerInfo;
using ControlNodeParser = runtime::ControlNodeParser;
using FuncGraphToKernelGraphGroup = runtime::FuncGraphToKernelGraphGroup;
using ControlNodeParserPtr = runtime::ControlNodeParserPtr;
using KernelWithIndex = session::KernelWithIndex;
enum SwitchCondStatus {
kCondOk = 0,
kCondAlreadyRun,
};
class BACKEND_EXPORT Backend {
public:
explicit Backend(const std::string &name);
virtual ~Backend() = default;
LinkFuncType convert_fn() { return convert_fn_; }
std::string name() { return name_; }
virtual bool GetCond(const BaseRef &c, bool *value);
virtual bool GetIndex(const BaseRef &c, int64_t *value);
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; }
virtual void SetDebugger() {}
bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; }
protected:
std::string name_;
LinkFuncType convert_fn_;
bool is_multi_graph_sink_;
};
class BACKEND_EXPORT MsBackend : public Backend {
public:
MsBackend(const std::string &name, const std::string &target, uint32_t device_id);
@ -102,59 +66,30 @@ class BACKEND_EXPORT MsBackend : public Backend {
mindspore::HashMap<GraphId, LinConvertResult> graph_id_map_;
};
class BACKEND_EXPORT MindRTBackend : public Backend {
class BACKEND_EXPORT MindRTBackend : public MindRTBackendBase {
public:
MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id);
MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id)
: MindRTBackendBase(backend_name, device_name, device_id) {}
~MindRTBackend() override = default;
// The parameter root_graph is a root graph, and the root graph maybe contain multiple sub graphs, It will traverse
// all sub graphs to call CompileGraph.
const ActorInfo &CompileGraphs(const FuncGraphPtr &func_graph);
// Run Graph in the graph mode.
void RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs);
// Run single op in the PyNative mode.
void RunOp(const session::BackendOpRunInfoPtr &op_run_info, VectorRef *outputs);
#ifdef ENABLE_DEBUGGER
void SetDebuggerInit();
#endif
// Execute all tasks in queue when lazy build is enabled in PyNative mode.
void WaitTaskFinish() const;
void WaitTaskFinish() const override;
// Clear resource when python exit.
void ClearOpExecutorResource() const;
// Get the device target.
std::string GetDeviceTarget() { return device_name_; }
// Sync default stream in PyNative mode.
void SyncStream();
private:
// The parameter func_graph is a graph, it can be either a root graph or a sub graph,
// The result of graph compiler is stored in graph_id_to_device_context_ and control_nodes_.
void CompileGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode);
// Compile the kernel graph by the segment which is from the function graph partition.
void CompileGraph(const GraphSegmentPtr &segment, device::RunMode run_mode);
// CreateKernel, Transform and Schedule have not been finished when LazyBuild is enabled in PyNative mode.
void CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context) const;
// Get saved OpBuildTask in OpExecutor and build all the kernels together in PyNative mode.
void CompileSingleOpGraphs(const std::vector<std::shared_ptr<runtime::OpBuildTask>> &build_tasks);
void ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, const FuncGraphPtr &root_graph);
// Restore the outputs tuple by the origin funcGraph output node and output tensors.
void ConstructOutputs(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors,
size_t *output_position, VectorRef *outputs);
// In the control flow, the output of the call node needs to be created by abstract.
BaseRef ConstructOutputByAbstract(const abstract::AbstractBasePtr &abstract,
const std::vector<tensor::TensorPtr> &output_tensors, size_t *output_position);
// Construct the GraphCompilerInfo by the compilation results of graph, used in Graph mode.
std::shared_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph);
void ParseControlNodes(const GraphCompilerInfo &graph_compile_info);
// In PyNative mode, the size of single op cache list will be increasing, which lead to memory cost increasing,
// so the latest single op cache should be erased when cache list size exceeds threshold value.
void EraseSingleOpCache(const GraphInfo &graph_info);
@ -171,48 +106,24 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
const session::BackendOpRunInfoPtr &op_run_info);
void RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
const VectorRef &args, VectorRef *outputs);
const VectorRef &args, VectorRef *outputs) override;
// Split complete kernel graph to single op graph in PyNative back
// propagation, then compile and run single op graph.
void RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_info, const VectorRef &args, VectorRef *outputs);
void RunGraphByActors(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
const VectorRef &args, VectorRef *outputs);
void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs);
void ReleaseForwardOutput(const std::vector<TensorPtr> &input_tensors);
void OpRunCallback(const std::shared_ptr<runtime::OpTaskContext> &context);
// When compiling FuncGraph, it is divided according to the control nodes, and obtain the control nodes and several
// node segments. Node segments will be compiled into kernelGraphs which are expressed as GraphId and bound to
// the corresponding device_context.
std::map<GraphId, DeviceContext *> graph_id_to_device_context_;
// Funcgraph will be cut into multiple kernel graphs, and the map is used to save the correspondence.
// The kernel graphs which not cut by control flow are placed in the same group.
std::map<FuncGraphPtr, std::vector<std::vector<GraphId>>> func_graph_to_kernel_graph_ids_;
std::map<GraphInfo, DeviceContext *> graph_info_to_device_context_;
std::vector<AnfNodePtr> control_nodes_;
mindspore::HashMap<ActorInfo, std::shared_ptr<GraphCompilerInfo>> actor_to_graph_compiler_info_;
// Cache output tensor ref count of kernels for back propagation graph in PyNative mode.
std::map<GraphId, std::map<KernelWithIndex, size_t>> cnode_ref_counts_;
// Cache forward op output value node tensor ref count of kernels for back propagation graph in PyNative mode.
std::map<std::string, size_t> forward_op_output_tensor_id_;
FuncGraphPtr root_graph_;
GraphPartitionPtr graph_partition_;
std::shared_ptr<GraphCompiler> graph_compiler_;
std::string device_name_;
uint32_t device_id_;
int ms_execution_mode_{kGraphMode};
int real_execution_mode_{kGraphMode};
void CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode = device::RunMode::kUnknown);
void ProcessNotSupportCnode(const FuncGraphPtr &func_graph, const device::DeviceType &old_target,
const device::DeviceType &new_target) const;
};
using MindRTBackendPtr = std::shared_ptr<compile::MindRTBackend>;
} // namespace compile

View File

@ -0,0 +1,753 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/graph_compiler/backend_base.h"
#include <algorithm>
#include <vector>
#include <map>
#include "backend/graph_compiler/transform.h"
#include "ir/anf.h"
#include "utils/log_adapter.h"
#include "runtime/graph_scheduler/graph_compiler.h"
#include "runtime/pynative/graph_adapter.h"
#include "distributed/recovery/recovery_context.h"
#include "include/common/utils/scoped_long_running.h"
#include "include/common/utils/callbacks.h"
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/debugger.h"
#endif
#ifdef WITH_BACKEND
#include "ps/ps_context.h"
#endif
namespace mindspore {
namespace compile {
bool Backend::GetCond(const BaseRef &c, bool *value) {
mindspore::ScopedLongRunning long_running;
return BaseRefToBool(c, value);
}
bool Backend::GetIndex(const BaseRef &c, int64_t *value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }
Backend::Backend(const std::string &name) : name_(name), is_multi_graph_sink_(false) {
MS_LOG(DEBUG) << "Select backend:" << name;
convert_fn_ = MsVmConvert;
}
void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs) {
MS_EXCEPTION_IF_NULL(inputs);
if (utils::isa<tensor::TensorPtr>(arg)) {
auto value = utils::cast<tensor::TensorPtr>(arg);
inputs->push_back(value);
} else if (utils::isa<ValuePtr>(arg)) {
auto value = utils::cast<ValuePtr>(arg);
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
auto tuple_value = value_tuple->value();
(void)std::transform(tuple_value.begin(), tuple_value.end(), std::back_inserter(*inputs),
[](const ValuePtr &v) { return v->cast<tensor::TensorPtr>(); });
} else if (value->isa<Scalar>()) {
tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>());
inputs->push_back(scalar_tensor);
} else if (value->isa<Monad>()) {
// If value is a monad, replace it with an unused tensor.
inputs->push_back(std::make_shared<tensor::Tensor>(int64_t(0), kBool));
} else {
inputs->push_back(value->cast<tensor::TensorPtr>());
}
} else if (utils::isa<PyObjectRef>(arg)) {
auto value = utils::cast<PyObjectRef>(arg).object_;
inputs->push_back(py::cast<tensor::TensorPtr>(value));
} else if (utils::isa<VectorRefPtr>(arg)) {
const auto &args_new = utils::cast<VectorRef>(arg);
for (const auto &v : args_new) {
PushInputTensor(v, inputs);
}
} else {
MS_LOG(WARNING) << "Invalid input type.";
}
}
namespace {
// Move these function to anonymous namespace
void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value) {
MS_EXCEPTION_IF_NULL(flatted_value);
for (auto value_element : value) {
MS_EXCEPTION_IF_NULL(value_element);
if (utils::isa<tensor::TensorPtr>(value_element)) {
(void)flatted_value->emplace_back(value_element);
} else if (utils::isa<ValueTuplePtr>(value_element)) {
auto value_tuple_element = value_element->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple_element);
FlatValueTupleValue(value_tuple_element->value(), flatted_value);
} else {
MS_LOG(EXCEPTION) << "The value input to FlatValueTupleValue should only contains Tensor and ValueTuple.";
}
}
}
void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) {
MS_EXCEPTION_IF_NULL(flatted_value);
if (utils::isa<ValueSequencePtr>(arg)) {
auto value_sequence = utils::cast<ValueSequencePtr>(arg);
MS_EXCEPTION_IF_NULL(value_sequence);
auto sequence_value = value_sequence->value();
for (auto &value : sequence_value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
(void)flatted_value->emplace_back(value);
} else {
FlattenValue(value, flatted_value);
}
}
} else if (utils::isa<ValueDictionaryPtr>(arg)) {
auto value_dict = utils::cast<ValueDictionaryPtr>(arg);
MS_EXCEPTION_IF_NULL(value_dict);
auto dict_value = value_dict->value();
for (auto &iter : dict_value) {
auto value = iter.second;
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
(void)flatted_value->emplace_back(value);
} else {
FlattenValue(value, flatted_value);
}
}
} else if (utils::isa<tensor::COOTensorPtr>(arg)) {
auto coo_tensor = utils::cast<tensor::COOTensorPtr>(arg);
MS_EXCEPTION_IF_NULL(coo_tensor);
for (size_t i = 0; i < coo_tensor->GetTensorLength(); ++i) {
(void)flatted_value->emplace_back(coo_tensor->GetTensorAt(i));
}
} else if (utils::isa<tensor::CSRTensorPtr>(arg)) {
auto csr_tensor = utils::cast<tensor::CSRTensorPtr>(arg);
MS_EXCEPTION_IF_NULL(csr_tensor);
for (size_t i = 0; i < csr_tensor->GetTensorLength(); ++i) {
(void)flatted_value->emplace_back(csr_tensor->GetTensorAt(i));
}
} else {
MS_LOG(EXCEPTION) << "The value input to flatten should only contains be sequence or dictionary, but it is "
<< arg.ToString();
}
}
// Insert the front_node related tensor in the input_tensor.
void PushTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
std::vector<tensor::TensorPtr> *input_tensors) {
MS_EXCEPTION_IF_NULL(input_tensors);
const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
if (iter == parameters.end()) {
(void)((*input_tensors).emplace_back(nullptr));
return;
}
auto position = iter - parameters.begin();
PushInputTensor(args[position], input_tensors);
}
void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
size_t index, std::vector<tensor::TensorPtr> *input_tensors) {
MS_EXCEPTION_IF_NULL(input_tensors);
const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
const size_t position = iter - parameters.begin();
// If the parameter is not found in the parameters of the root graph, it means that it is the input of the subgraph,
// and there is no need to input a tensor.
if (position >= args.size()) {
MS_LOG(DEBUG) << "Position out of args range, position value is " << position << " and args size is " << args.size()
<< ".";
(void)input_tensors->emplace_back(nullptr);
return;
}
ValuePtrList flatted_value_tuple_value;
FlattenValue(args[position], &flatted_value_tuple_value);
if (index >= flatted_value_tuple_value.size()) {
MS_LOG(EXCEPTION) << "Index out of flatted_value_tuple_value range, index value is " << index
<< " and flatted_value_tuple_value size is " << flatted_value_tuple_value.size() << ".";
}
auto input = flatted_value_tuple_value[index];
MS_EXCEPTION_IF_NULL(input);
auto tensor_input = input->cast<tensor::TensorPtr>();
input_tensors->push_back(tensor_input);
}
} // namespace
std::vector<std::vector<tensor::TensorPtr>> GetRunGraphInputs(const GraphCompilerInfo &graph_compiler_info,
const VectorRef &args) {
const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
std::vector<std::vector<tensor::TensorPtr>> input_tensor_lists;
for (const auto &kernel_graph : graph_compiler_info.graphs_) {
std::vector<tensor::TensorPtr> input_tensors;
MS_EXCEPTION_IF_NULL(kernel_graph);
for (const auto &input_node : kernel_graph->input_nodes()) {
auto element_pair = kernel_graph->GetElementInTupleBackendFrontIndexMap(input_node);
if (element_pair.first) {
PushTupleTensor(args, origin_parameters, element_pair.first, element_pair.second, &input_tensors);
} else {
const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
PushTensor(args, origin_parameters, front_node, &input_tensors);
}
}
(void)input_tensor_lists.emplace_back(input_tensors);
}
// Input tensors of the control node.
std::vector<tensor::TensorPtr> input_tensors;
MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
// Get inputs of control node which come from the host actor.
const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
for (const auto &parameter_with_index : control_node_parameters) {
const auto &parameter = parameter_with_index.first;
MS_EXCEPTION_IF_NULL(parameter);
const auto &abs = parameter->abstract();
MS_EXCEPTION_IF_NULL(abs);
if (abs->isa<abstract::AbstractTuple>()) {
MS_LOG(DEBUG) << "Fetch input tensor for tuple parameter:" << parameter->DebugString() << " in control flow.";
PushTupleTensor(args, origin_parameters, parameter, parameter_with_index.second, &input_tensors);
} else {
PushTensor(args, origin_parameters, parameter, &input_tensors);
}
}
(void)input_tensor_lists.emplace_back(input_tensors);
return input_tensor_lists;
}
MindRTBackendBase::MindRTBackendBase(const std::string &backend_name, const std::string &device_name,
uint32_t device_id)
: Backend(backend_name), device_name_(device_name) {
root_graph_ = nullptr;
auto ms_context = MsContext::GetInstance();
const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
auto &cut_list = pynative_mode ? GetControlOps() : GetMsNonlinearOps();
graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend_name);
graph_compiler_ = std::make_shared<GraphCompiler>();
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
device_context->Initialize();
device_id_ = device_context->device_context_key().device_id_;
#ifdef ENABLE_DEBUGGER
SetDebuggerInit();
#endif
runtime::GraphScheduler::GetInstance().Initialize();
}
void MindRTBackendBase::ProcessNotSupportCnode(const FuncGraphPtr &func_graph,
const mindspore::device::DeviceType &old_target,
const mindspore::device::DeviceType &new_target) const {
const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
for (const auto &node : all_nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (!common::AnfAlgo::HasNodeAttr(kAttrNotSupportOpForDevice, cnode)) {
continue;
}
auto not_support_device = common::AnfAlgo::GetNodeAttr<std::string>(node, kAttrNotSupportOpForDevice);
if (device::GetDeviceTypeByName(not_support_device) != old_target) {
continue;
}
common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(device::GetDeviceNameByType(new_target)), node);
}
}
const ActorInfo &MindRTBackendBase::CompileGraphs(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(graph_compiler_);
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(INFO) << "Status record: start compile function graph: " << func_graph->ToString();
PROF_START(compile_func_graph);
auto root_graph = WrapPrimitives(func_graph);
MS_EXCEPTION_IF_NULL(root_graph);
root_graph_ = root_graph;
// Register a summary callback function, which is called in the final stages of summary.
graph_compiler_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
real_execution_mode_ = ms_execution_mode_;
func_graph->set_flag(kFlagPyNativeRunInGraph, real_execution_mode_ == kPynativeMode);
// Compile root graph.
graph_id_to_device_context_.clear();
func_graph_to_kernel_graph_ids_.clear();
control_nodes_.clear();
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
MS_EXCEPTION_IF_NULL(device_context);
bool all_support = device_context->PartitionGraph(func_graph);
if (all_support) {
auto run_mode = device_context->GetRunMode(func_graph);
if (run_mode == device::RunMode::kGraphMode && pynative::GraphAdapter::PyNativeEnableTaskSink(func_graph)) {
auto graph_id = graph_compiler_->CompileWholeGraphForGraphRunMode(func_graph, device_context);
graph_id_to_device_context_[graph_id] = device_context;
} else {
CompileSubGraph(func_graph, device::RunMode::kKernelMode);
}
} else {
ProcessNotSupportCnode(func_graph, device_context->GetDeviceType(), mindspore::device::DeviceType::kCPU);
CompileSubGraph(func_graph);
}
// Construct the graph compiler info.
auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph);
MS_EXCEPTION_IF_NULL(graph_compiler_info);
if (real_execution_mode_ == kGraphMode &&
((!graph_compiler_info->graphs_.empty()) || graph_compiler_info->control_nodes_.size() > 1)) {
// Transform graph to actor DAG, and schedule the actor DAG.
ParseControlNodes(*graph_compiler_info);
const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
runtime::GraphScheduler::GetInstance().Schedule(actor_set);
}
const ActorInfo &actor_info = graph_compiler_info->name_;
(void)actor_to_graph_compiler_info_.emplace(graph_compiler_info->name_, std::move(graph_compiler_info));
PROF_END(compile_func_graph);
if (ms_execution_mode_ != real_execution_mode_) {
context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, ms_execution_mode_);
}
MS_LOG(INFO) << "Status record: end compile function graph: " << func_graph->ToString()
<< ", produce actor: " << actor_info;
return actor_info;
}
void MindRTBackendBase::CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
auto root_graph = WrapPrimitives(func_graph);
MS_EXCEPTION_IF_NULL(root_graph);
CompileGraph(root_graph, run_mode);
MS_EXCEPTION_IF_NULL(root_graph->manager());
FuncGraphSet sub_graphs = root_graph->manager()->func_graphs();
for (const auto &sub_graph : sub_graphs) {
if (sub_graph != func_graph && sub_graph != nullptr) {
CompileGraph(sub_graph, run_mode);
}
}
}
void MindRTBackendBase::CompileGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(graph_partition_);
MS_EXCEPTION_IF_NULL(graph_compiler_);
bool contain_multi_target = false;
// Split graph to segments.
const auto &segments = graph_partition_->Partition(func_graph, &contain_multi_target);
MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", Split segments size:" << segments.size();
// Foreach the segments to compile graph.
for (const auto &segment : segments) {
CompileGraph(segment, run_mode);
}
}
void MindRTBackendBase::CompileGraph(const GraphSegmentPtr &segment, device::RunMode run_mode) {
MS_EXCEPTION_IF_NULL(segment);
// Compile the normal nodes, which doesn't contain the cut node.
if (segment->nodes_.size() == 0) {
MS_LOG(EXCEPTION) << "The segments size is 0.";
}
if (!segment->is_cut_) {
MS_EXCEPTION_IF_NULL(segment->nodes_[0]);
MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->DebugString();
// Get the device context.
const auto &cur_device_name = GetCNodeTarget(segment->nodes_[0]);
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_});
MS_EXCEPTION_IF_NULL(device_context);
device_context->Initialize();
// Transform nodes to inputs and outputs.
FuncGraphPtr fg;
AnfNodePtrList inputs;
AnfNodePtrList outputs;
std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
// Compile graph.
auto graph_id =
graph_compiler_->CompileGraph(segment, outputs, device_context, run_mode, real_execution_mode_ == kPynativeMode);
graph_id_to_device_context_[graph_id] = device_context;
const auto &func_graph = segment->nodes_[0]->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
if (func_graph_to_kernel_graph_ids_.find(func_graph) == func_graph_to_kernel_graph_ids_.end()) {
(void)func_graph_to_kernel_graph_ids_[func_graph].emplace_back(std::vector<GraphId>{graph_id});
} else {
(void)func_graph_to_kernel_graph_ids_[func_graph].back().emplace_back(graph_id);
}
} else {
// Compile the cut node.
auto cut_node = segment->nodes_[0];
MS_EXCEPTION_IF_NULL(cut_node);
MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->DebugString();
control_nodes_.push_back(cut_node);
if (common::AnfAlgo::IsCallNode(cut_node) || common::AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitch) ||
common::AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitchLayer)) {
const auto &func_graph = cut_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
(void)func_graph_to_kernel_graph_ids_[func_graph].emplace_back(std::vector<GraphId>());
}
}
}
namespace {
void TensorValueToVector(const ValuePtr &value, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(value);
MS_EXCEPTION_IF_NULL(outputs);
if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
for (size_t i = 0; i < value_tuple->size(); ++i) {
ValuePtr element = value_tuple->value()[i];
MS_EXCEPTION_IF_NULL(element);
if (element->isa<tensor::Tensor>()) {
auto tensor = element->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
outputs->emplace_back(tensor);
} else if (element->isa<ValueTuple>()) {
VectorRef tuple;
TensorValueToVector(element, &tuple);
outputs->emplace_back(tuple);
}
}
} else if (value->isa<tensor::Tensor>()) {
auto tensor = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
outputs->emplace_back(tensor);
}
}
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &graph_output, const VectorRef &args, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(graph_output);
MS_EXCEPTION_IF_NULL(outputs);
if (graph_output->isa<ValueNode>()) {
MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
VectorRef output_tmp;
ValuePtr value = GetValueNode(graph_output);
TensorValueToVector(value, &output_tmp);
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>()) {
outputs->emplace_back(output_tmp);
} else if (value->isa<tensor::Tensor>()) {
*outputs = output_tmp;
} else {
MS_LOG(INFO) << "Graph output is empty!";
}
return true;
}
if (graph_output->isa<Parameter>()) {
MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
// Find the right parameter as ret_val.
auto func_graph = graph_output->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto params = func_graph->parameters();
if (args.size() != params.size()) {
MS_LOG(EXCEPTION) << "Input size " << args.size() << " not equal to graph input size " << params.size();
}
auto it = std::find(params.begin(), params.end(), graph_output);
if (it == params.end()) {
MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
}
size_t index = it - params.cbegin();
if (index >= args.size()) {
MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size();
}
outputs->emplace_back(args[index]);
return true;
}
return false;
}
} // namespace
void MindRTBackendBase::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs,
const FuncGraphPtr &root_graph) {
bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
distributed::recovery::RecoveryContext::GetInstance()->need_reset());
bool is_embedding_cache_server = false;
#ifdef WITH_BACKEND
is_embedding_cache_server = ps::PSContext::instance()->cache_enable() && ps::PSContext::instance()->is_server();
#endif
if (need_contruct_output) {
// Update device address for output node of graph.
// Summary processing will use the output device address, so must be after the summary processing.
if (!is_embedding_cache_server) {
actor_set->output_actor_->UpdateOutputDeviceAddress();
}
// Fetch outputs.
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
auto &output_tensors = actor_set->output_actor_->outputs();
if (!output_tensors.empty()) {
size_t output_position = 0;
ConstructOutputs(root_graph->output(), output_tensors, &output_position, outputs);
}
}
}
void MindRTBackendBase::RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(root_graph_);
if (IsGraphOutputValueNodeOrParameter(root_graph_->output(), args, outputs)) {
return;
}
const auto &context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
MS_LOG(INFO) << "PrecompileOnly, stop run graph";
return;
}
// Open abstract_lock for dynamic_shape
AnfUtils::OpenAbstractLock();
MS_LOG(INFO) << "Status record: start run actor: " << actor_info;
// Fetch the graph compiler info.
const auto &graph_iter = actor_to_graph_compiler_info_.find(actor_info);
if (graph_iter == actor_to_graph_compiler_info_.end()) {
MS_LOG(EXCEPTION) << "Can't find the graph compiler info.";
}
MS_EXCEPTION_IF_NULL(graph_iter->second);
const auto &graph_compiler_info = *(graph_iter->second);
// For pynative and graph mix execution.
WaitTaskFinish();
// Run in the pynative mode.
MS_EXCEPTION_IF_NULL(outputs);
// There will be more than one kernel graph in heterogeneous scenario in a ms function of PyNative Mode.
if (real_execution_mode_ == kPynativeMode) {
RunGraphByCondition(actor_info, graph_compiler_info, args, outputs);
return;
}
auto input_tensors = GetRunGraphInputs(graph_compiler_info, args);
// Release python gil.
mindspore::ScopedLongRunning long_running;
// Run actor DAG.
const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
MS_EXCEPTION_IF_NULL(actor_set);
runtime::GraphScheduler::GetInstance().Run(actor_set, input_tensors);
MS_EXCEPTION_IF_NULL(graph_compiler_);
graph_compiler_->Summary(graph_compiler_info.graphs_);
ConstructOutputs(actor_set, outputs, root_graph_);
runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
// Close abstract_lock for dynamic_shape
AnfUtils::CloseAbstractLock();
MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
}
BaseRef MindRTBackendBase::ConstructOutputByAbstract(const abstract::AbstractBasePtr &abstract,
const std::vector<tensor::TensorPtr> &output_tensors,
size_t *output_position) {
MS_EXCEPTION_IF_NULL(abstract);
MS_EXCEPTION_IF_NULL(output_position);
size_t outputs_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
if (*output_position + outputs_num > output_tensors.size()) {
MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position << " need:" << outputs_num
<< " total:" << output_tensors.size();
}
VectorRef outputs;
if (!abstract->isa<abstract::AbstractTuple>()) {
(*output_position)++;
return output_tensors[(*output_position) - 1];
}
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
for (const auto &sub_abstract : sub_abstracts) {
MS_EXCEPTION_IF_NULL(sub_abstract);
outputs.emplace_back(ConstructOutputByAbstract(sub_abstract, output_tensors, output_position));
}
return outputs;
}
void MindRTBackendBase::ConstructOutputs(const AnfNodePtr &output_node,
const std::vector<tensor::TensorPtr> &output_tensors, size_t *output_position,
VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(output_node);
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(output_position);
const PrimitiveSet expand_prims{
prim::kPrimMakeTuple,
prim::kPrimMakeCSRTensor,
prim::kPrimMakeCOOTensor,
prim::kPrimMakeRowTensor,
};
// The MakeTuple/MakeSaprse node need expand and recurse.
if (IsOneOfPrimitiveCNode(output_node, expand_prims)) {
auto make_tuple = output_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
VectorRef make_tuple_output;
for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
ConstructOutputs(make_tuple->input(i), output_tensors, output_position, &make_tuple_output);
}
outputs->emplace_back(std::move(make_tuple_output));
return;
}
// The depend node need get the real node.
if (common::AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimDepend)) {
auto depend_node = output_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_node);
ConstructOutputs(depend_node->input(kRealInputIndexInDepend), output_tensors, output_position, outputs);
return;
}
auto outputs_num = common::AnfAlgo::GetOutputTensorNum(output_node);
// The value node uses the value to be output, to avoid the host memory of value free due to value node destruction.
if (output_node->isa<ValueNode>()) {
auto value = output_node->cast<ValueNodePtr>()->value();
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>()) {
outputs->emplace_back(value);
(*output_position) += CountValueNum(value->cast<ValueTuplePtr>());
} else if (outputs_num != 0) {
outputs->emplace_back(value);
(*output_position) += outputs_num;
}
// The empty value node return the empty VectorRef.
return;
}
if (common::AnfAlgo::IsCallNode(output_node)) {
auto abstract = output_node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
outputs->emplace_back(ConstructOutputByAbstract(abstract, output_tensors, output_position));
return;
}
auto &output_abstract = output_node->abstract();
MS_EXCEPTION_IF_NULL(output_abstract);
// Wrap output to VectorRef if the output is tuple.
if (output_abstract->isa<abstract::AbstractTuple>()) {
VectorRef output_tuple;
for (size_t i = 0; i < outputs_num; ++i) {
if (*output_position >= output_tensors.size()) {
MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position;
}
output_tuple.emplace_back(std::move(output_tensors[*output_position]));
++(*output_position);
}
outputs->emplace_back(std::move(output_tuple));
} else {
for (size_t i = 0; i < outputs_num; ++i) {
if (*output_position >= output_tensors.size()) {
MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position;
}
outputs->emplace_back(std::move(output_tensors[*output_position]));
++(*output_position);
}
}
}
#ifdef ENABLE_DEBUGGER
void MindRTBackendBase::SetDebuggerInit() {
auto debugger_ = Debugger::GetInstance();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
}
#endif
std::shared_ptr<GraphCompilerInfo> MindRTBackendBase::ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph) {
MS_EXCEPTION_IF_NULL(root_graph);
MS_EXCEPTION_IF_NULL(graph_compiler_);
std::vector<KernelGraphPtr> graphs;
std::vector<DeviceContext *> device_contexts;
std::string name = "kernel_graph";
size_t graph_index = 0;
for (const auto &graph_id_to_context : graph_id_to_device_context_) {
(void)graphs.emplace_back(graph_compiler_->Fetch(graph_id_to_context.first));
(void)device_contexts.emplace_back(graph_id_to_context.second);
if (graph_index == 0) {
(void)name.append("_").append(std::to_string(graph_id_to_context.first));
} else if (graph_index == graph_id_to_device_context_.size() - 1) {
(void)name.append("-").append(std::to_string(graph_id_to_context.first));
}
++graph_index;
}
auto parser = std::make_shared<ControlNodeParser>();
runtime::KernelMapPosition outputs_order;
const auto &root_output =
common::AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
size_t position = 0;
auto outputs = common::AnfAlgo::GetAllOutputWithIndex(root_output);
size_t outputs_num = outputs.size();
for (const auto &output : outputs) {
if (outputs_order.count(output) == 0) {
outputs_order[output] = {position++};
} else {
(void)outputs_order[output].emplace_back(position++);
}
}
std::vector<std::vector<int64_t> *> tensors_mask;
std::vector<std::vector<tensor::TensorPtr> *> input_tensors;
auto strategy = runtime::GraphExecutionStrategy::kPipeline;
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<int>(MS_CTX_MEMORY_OPTIMIZE_LEVEL) != kOptimizeO0) {
strategy = runtime::GraphExecutionStrategy::kPipelineWithExecutionOrder;
}
return std::make_shared<GraphCompilerInfo>(graphs, device_contexts, tensors_mask, input_tensors, control_nodes_,
root_graph->parameters(), parser, outputs_order, outputs_num, name, false,
strategy);
}
void MindRTBackendBase::ParseControlNodes(const GraphCompilerInfo &graph_compile_info) {
FuncGraphToKernelGraphGroup func_graph_to_kernel_graphs;
for (const auto &func_graph_to_kernel_graph_ids : func_graph_to_kernel_graph_ids_) {
const auto &func_graph = func_graph_to_kernel_graph_ids.first;
for (const auto &sub_kernel_graphs_ids : func_graph_to_kernel_graph_ids.second) {
std::vector<KernelGraphPtr> kernel_graphs;
for (const auto &graph_id : sub_kernel_graphs_ids) {
const auto &kernel_graph = graph_compiler_->Fetch(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
(void)kernel_graphs.emplace_back(kernel_graph);
}
(void)func_graph_to_kernel_graphs[func_graph].emplace_back(kernel_graphs);
}
}
graph_compile_info.control_node_parser_->Parse(control_nodes_, graph_compile_info.graphs_,
graph_compile_info.device_contexts_, root_graph_,
func_graph_to_kernel_graphs);
}
} // namespace compile
} // namespace mindspore

View File

@ -0,0 +1,146 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_VM_BACKENDBASE_H_
#define MINDSPORE_CCSRC_VM_BACKENDBASE_H_
#include <list>
#include <memory>
#include <string>
#include <map>
#include <set>
#include <utility>
#include <vector>
#include "utils/hash_map.h"
#include "ir/anf.h"
#include "backend/common/session/session_basic.h"
#include "runtime/hardware/device_context.h"
#include "backend/graph_compiler/segment_runner.h"
#include "runtime/graph_scheduler/actor/actor_set.h"
namespace mindspore {
namespace compile {
using GraphOutputInfo = session::GraphOutputInfo;
using DeviceContext = device::DeviceContext;
using ActorInfo = runtime::ActorInfo;
using GraphCompiler = runtime::GraphCompiler;
using GraphCompilerInfo = runtime::GraphCompilerInfo;
using ControlNodeParser = runtime::ControlNodeParser;
using FuncGraphToKernelGraphGroup = runtime::FuncGraphToKernelGraphGroup;
using ControlNodeParserPtr = runtime::ControlNodeParserPtr;
using KernelWithIndex = session::KernelWithIndex;
enum SwitchCondStatus {
kCondOk = 0,
kCondAlreadyRun,
};
class BACKEND_EXPORT Backend {
public:
explicit Backend(const std::string &name);
virtual ~Backend() = default;
LinkFuncType convert_fn() { return convert_fn_; }
std::string name() { return name_; }
virtual bool GetCond(const BaseRef &c, bool *value);
virtual bool GetIndex(const BaseRef &c, int64_t *value);
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; }
virtual void SetDebugger() {}
bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; }
protected:
std::string name_;
LinkFuncType convert_fn_;
bool is_multi_graph_sink_;
};
void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs);
std::vector<std::vector<tensor::TensorPtr>> GetRunGraphInputs(const GraphCompilerInfo &graph_compiler_info,
const VectorRef &args);
class BACKEND_EXPORT MindRTBackendBase : public Backend {
public:
MindRTBackendBase(const std::string &backend_name, const std::string &device_name, uint32_t device_id);
~MindRTBackendBase() override = default;
// The parameter root_graph is a root graph, and the root graph maybe contain multiple sub graphs, It will traverse
// all sub graphs to call CompileGraph.
const ActorInfo &CompileGraphs(const FuncGraphPtr &func_graph);
// Run Graph in the graph mode.
void RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs);
#ifdef ENABLE_DEBUGGER
void SetDebuggerInit();
#endif
// Get the device target.
std::string GetDeviceTarget() { return device_name_; }
virtual void WaitTaskFinish() const {}
virtual void RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
const VectorRef &args, VectorRef *outputs) {}
protected:
// The parameter func_graph is a graph, it can be either a root graph or a sub graph,
// The result of graph compiler is stored in graph_id_to_device_context_ and control_nodes_.
void CompileGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode);
// Compile the kernel graph by the segment which is from the function graph partition.
void CompileGraph(const GraphSegmentPtr &segment, device::RunMode run_mode);
void ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, const FuncGraphPtr &root_graph);
// Restore the outputs tuple by the origin funcGraph output node and output tensors.
void ConstructOutputs(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors,
size_t *output_position, VectorRef *outputs);
// In the control flow, the output of the call node needs to be created by abstract.
BaseRef ConstructOutputByAbstract(const abstract::AbstractBasePtr &abstract,
const std::vector<tensor::TensorPtr> &output_tensors, size_t *output_position);
// Construct the GraphCompilerInfo by the compilation results of graph, used in Graph mode.
std::shared_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph);
void ParseControlNodes(const GraphCompilerInfo &graph_compile_info);
// When compiling FuncGraph, it is divided according to the control nodes, and obtain the control nodes and several
// node segments. Node segments will be compiled into kernelGraphs which are expressed as GraphId and bound to
// the corresponding device_context.
std::map<GraphId, DeviceContext *> graph_id_to_device_context_;
// Funcgraph will be cut into multiple kernel graphs, and the map is used to save the correspondence.
// The kernel graphs which not cut by control flow are placed in the same group.
std::map<FuncGraphPtr, std::vector<std::vector<GraphId>>> func_graph_to_kernel_graph_ids_;
std::map<GraphInfo, DeviceContext *> graph_info_to_device_context_;
std::vector<AnfNodePtr> control_nodes_;
mindspore::HashMap<ActorInfo, std::shared_ptr<GraphCompilerInfo>> actor_to_graph_compiler_info_;
FuncGraphPtr root_graph_;
GraphPartitionPtr graph_partition_;
std::shared_ptr<GraphCompiler> graph_compiler_;
std::string device_name_;
uint32_t device_id_;
int ms_execution_mode_{kGraphMode};
int real_execution_mode_{kGraphMode};
void CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode = device::RunMode::kUnknown);
void ProcessNotSupportCnode(const FuncGraphPtr &func_graph, const device::DeviceType &old_target,
const device::DeviceType &new_target) const;
};
} // namespace compile
} // namespace mindspore
#endif