!26045 Delete useless interface in control node parser.

Merge pull request !26045 from gaoyong10/runtime_second8
This commit is contained in:
i-robot 2021-11-10 03:00:02 +00:00 committed by Gitee
commit 4eef5e5c13
14 changed files with 879 additions and 1485 deletions

View File

@ -132,16 +132,6 @@ bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
return false;
}
// Returns true when the front node is a non-weight Parameter whose owning
// funcgraph already has a gather actor registered under the funcgraph's name.
bool IsGatherActor(const AnfNodePtr &front_node,
                   const std::unordered_map<std::string, OpActor<DeviceTensor> *> &actor_name_to_actor) {
  MS_EXCEPTION_IF_NULL(front_node);
  // Only non-weight parameters can belong to a gather actor.
  if (!front_node->isa<Parameter>() || AnfAlgo::IsParameterWeight(front_node->cast<ParameterPtr>())) {
    return false;
  }
  const auto &func_graph = front_node->func_graph();
  // Gather actors are keyed by the funcgraph's string name.
  return (func_graph != nullptr) && (actor_name_to_actor.count(func_graph->ToString()) > 0);
}
bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor) {
MS_EXCEPTION_IF_NULL(dst_device_tensor);
MS_EXCEPTION_IF_NULL(src_device_tensor);

View File

@ -133,10 +133,6 @@ bool IsInternalParameter(const AnfNodePtr &node, const KernelGraphPtr &graph);
// Judge whether the device tensor of the node is persistent or not.
bool IsPersistentDeviceTensor(const AnfNodePtr &node);
// Judge whether the front node is in a gather actor.
bool IsGatherActor(const AnfNodePtr &front_node,
const std::unordered_map<std::string, OpActor<DeviceTensor> *> &actor_name_to_actor);
// Copy data from src_device_tensor to dst_device_tensor.
bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);

View File

@ -177,14 +177,157 @@ void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) {
ofs << "\n";
}
void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
// Dump the arrow sets common to every control actor (data, control, partial
// and branch-id output arrows) to ofs. Shared by the per-type dump helpers
// (switch/gather/entrance/exit/stack).
void DumpControlActor(const ControlActor *actor, std::ofstream &ofs) {
  // Guard against a null actor before dereferencing, consistent with the
  // other Dump*Actor helpers in this file.
  MS_EXCEPTION_IF_NULL(actor);
  const auto &output_data_arrows = actor->output_data_arrows();
  if (output_data_arrows.size() > 0) {
    ofs << "\t\t\toutput_data_arrows:" << output_data_arrows.size() << "\n ";
    for (const auto &data_arrow : output_data_arrows) {
      MS_EXCEPTION_IF_NULL(data_arrow);
      ofs << "\t\t\t\tfrom_output_index:" << data_arrow->from_output_index_
          << "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
          << "\n";
    }
  }
  const auto &output_control_arrows = actor->output_control_arrows();
  if (output_control_arrows.size() > 0) {
    ofs << "\t\t\toutput_control_arrows:" << output_control_arrows.size() << "\n ";
    for (const auto &aid : output_control_arrows) {
      ofs << "\t\t\t\tto_actor_name:" << aid.Name() << "\n";
    }
  }
  const auto &output_partial_arrows = actor->output_partial_arrows();
  if (output_partial_arrows.size() > 0) {
    ofs << "\t\t\toutput_partial_arrows:" << output_partial_arrows.size() << "\n ";
    for (const auto &partial_arrow : output_partial_arrows) {
      MS_EXCEPTION_IF_NULL(partial_arrow);
      ofs << "\t\t\t\tfrom_output_index:" << partial_arrow->from_output_index_
          << "\tto_actor_name:" << partial_arrow->to_op_id_.Name()
          << "\tto_input_index:" << partial_arrow->to_input_index_ << "\n";
    }
  }
  const auto &output_branch_id_arrows = actor->output_branch_id_arrows();
  if (output_branch_id_arrows.size() > 0) {
    ofs << "\t\t\toutput_branch_id_arrows:" << output_branch_id_arrows.size() << "\n ";
    for (const auto &aid : output_branch_id_arrows) {
      ofs << "\t\t\t\tto_actor_name:" << aid.Name() << "\n";
    }
  }
}
// Dump one switch actor: its name followed by the arrow sets common to all
// control actors. The original printed the actor name twice, the second time
// with a typo'd "\t\ttactor_name" prefix (leftover diff interleave); emit a
// single, correctly-prefixed line matching the other dump helpers.
void DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) {
  MS_EXCEPTION_IF_NULL(actor);
  ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n';
  DumpControlActor(actor, ofs);
}
// Dump one gather actor: its name, the arrow sets common to all control
// actors, then the per-funcgraph output arrows that also carry a branch id.
void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) {
  MS_EXCEPTION_IF_NULL(actor);
  ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n';
  // Shared dump of data/control/partial/branch-id output arrows.
  DumpControlActor(actor, ofs);
  // Arrows keyed by the destination funcgraph; each key maps to the list of
  // destination actor ids.
  const auto &output_data_with_branch_id_arrows = actor->output_data_with_branch_id_arrows();
  if (output_data_with_branch_id_arrows.size() > 0) {
    ofs << "\t\t\toutput_data_with_branch_id_arrows:" << output_data_with_branch_id_arrows.size() << "\n ";
    for (const auto &output_data_with_branch_id_arrow : output_data_with_branch_id_arrows) {
      // NOTE(review): assumes the funcgraph key is non-null — confirm upstream.
      ofs << "\t\t\t\tbranch funcgraph:" << output_data_with_branch_id_arrow.first->ToString() << "\n";
      for (const auto &arrow : output_data_with_branch_id_arrow.second) {
        ofs << "\t\t\t\t\tto actor:" << arrow << "\n";
      }
    }
  }
}
// Dump one entrance actor: its name plus the arrow sets shared by every
// control actor.
void DumpEntranceActor(const EntranceActor *actor, std::ofstream &ofs) {
  MS_EXCEPTION_IF_NULL(actor);
  const auto &actor_name = actor->GetAID().Name();
  ofs << "\t\tactor_name:" << actor_name << '\n';
  DumpControlActor(actor, ofs);
}
// Dump one exit actor: its name, the common control actor arrows, then the
// three per-branch arrow maps (data, partial, control) keyed by branch id.
void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) {
  MS_EXCEPTION_IF_NULL(actor);
  ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n';
  // Shared dump of data/control/partial/branch-id output arrows.
  DumpControlActor(actor, ofs);
  // Per-branch data arrows: branch id -> list of data arrows.
  const auto &output_branch_data_arrows = actor->output_branch_data_arrows();
  if (output_branch_data_arrows.size() > 0) {
    ofs << "\t\t\toutput_branch_data_arrows:" << output_branch_data_arrows.size() << "\n ";
    for (const auto &output_branch_data_arrow : output_branch_data_arrows) {
      ofs << "\t\t\t\tbranch id:" << output_branch_data_arrow.first << "\n";
      for (const auto &arrow : output_branch_data_arrow.second) {
        MS_EXCEPTION_IF_NULL(arrow);
        ofs << "\t\t\t\t\tfrom_output_index:" << arrow->from_output_index_
            << "\tto_actor_name:" << arrow->to_op_id_.Name() << "\tto_input_index:" << arrow->to_input_index_ << "\n";
      }
    }
  }
  // Per-branch partial arrows: branch id -> list of partial arrows.
  const auto &output_branch_partial_arrows = actor->output_branch_partial_arrows();
  if (output_branch_partial_arrows.size() > 0) {
    ofs << "\t\t\toutput_branch_partial_arrows:" << output_branch_partial_arrows.size() << "\n ";
    for (const auto &output_branch_partial_arrow : output_branch_partial_arrows) {
      ofs << "\t\t\t\tbranch id:" << output_branch_partial_arrow.first << "\n";
      for (const auto &arrow : output_branch_partial_arrow.second) {
        MS_EXCEPTION_IF_NULL(arrow);
        ofs << "\t\t\t\t\tfrom_output_index:" << arrow->from_output_index_
            << "\tto_actor_name:" << arrow->to_op_id_.Name() << "\tto_input_index:" << arrow->to_input_index_ << "\n";
      }
    }
  }
  // Per-branch control arrows: branch id -> list of destination actor ids.
  const auto &output_branch_control_arrows = actor->output_branch_control_arrows();
  if (output_branch_control_arrows.size() > 0) {
    ofs << "\t\t\toutput_branch_control_arrows:" << output_branch_control_arrows.size() << "\n ";
    for (const auto &output_branch_control_arrow : output_branch_control_arrows) {
      ofs << "\t\t\t\tbranch id:" << output_branch_control_arrow.first << "\n";
      for (const auto &arrow : output_branch_control_arrow.second) {
        ofs << "\t\t\t\t\tto actor:" << arrow << "\n";
      }
    }
  }
}
// Dump one stack actor: its name plus the arrow sets shared by every control
// actor.
void DumpStackActor(const StackActor *actor, std::ofstream &ofs) {
  MS_EXCEPTION_IF_NULL(actor);
  const auto &actor_name = actor->GetAID().Name();
  ofs << "\t\tactor_name:" << actor_name << '\n';
  DumpControlActor(actor, ofs);
}
void DumpSwitchActors(const std::vector<SwitchActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Switch actors:" << actors.size() << "]\n";
for (const auto &switch_actor : actors) {
DumpSwitchActor(switch_actor.get(), ofs);
}
}
void DumpGatherActors(const std::vector<GatherActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Gather actors:" << actors.size() << "]\n";
for (const auto &gather_actor : actors) {
DumpGatherActor(gather_actor.get(), ofs);
}
}
void DumpEntranceActors(const std::vector<EntranceActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Entrance actors:" << actors.size() << "]\n";
for (const auto &entrance_actor : actors) {
DumpEntranceActor(entrance_actor.get(), ofs);
}
}
void DumpExitActors(const std::vector<ExitActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Exit actors:" << actors.size() << "]\n";
for (const auto &exit_actor : actors) {
DumpExitActor(exit_actor.get(), ofs);
}
}
void DumpStackActors(const std::vector<StackActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n\t[Stack actors:" << actors.size() << "]\n";
for (const auto &stack_actor : actors) {
DumpStackActor(stack_actor.get(), ofs);
}
}
} // namespace
@ -281,18 +424,17 @@ void DumpCopyActors(const std::vector<CopyActorPtr> &actors, std::ofstream &ofs)
}
}
void DumpGatherActors(const std::vector<GatherActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n[Gather actors:" << actors.size() << "]\n";
for (const auto &gather_actor : actors) {
DumpGatherActor(gather_actor.get(), ofs);
void DumpControlActors(const ControlActorSetPtr &control_actor_set, std::ofstream &ofs) {
ofs << "\n\n[Control actors]\n";
if (control_actor_set == nullptr) {
return;
}
}
void DumpSwitchActors(const std::vector<SwitchActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n[Switch actors:" << actors.size() << "]\n";
for (const auto &switch_actor : actors) {
DumpSwitchActor(switch_actor.get(), ofs);
}
DumpSwitchActors(control_actor_set->switch_actors_, ofs);
DumpGatherActors(control_actor_set->gather_actors_, ofs);
DumpEntranceActors(control_actor_set->entrance_actors_, ofs);
DumpExitActors(control_actor_set->exit_actors_, ofs);
DumpStackActors(control_actor_set->stack_actors_, ofs);
}
} // namespace runtime
} // namespace mindspore

View File

@ -30,8 +30,13 @@
#include "runtime/framework/actor/super_kernel_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/copy_actor.h"
#include "runtime/framework/actor/control_flow/control_actor.h"
#include "runtime/framework/actor/control_flow/switch_actor.h"
#include "runtime/framework/actor/control_flow/gather_actor.h"
#include "runtime/framework/actor/control_flow/entrance_actor.h"
#include "runtime/framework/actor/control_flow/exit_actor.h"
#include "runtime/framework/actor/control_flow/stack_actor.h"
#include "runtime/framework/control_node_scheduler.h"
namespace mindspore {
namespace runtime {
@ -43,8 +48,7 @@ void DumpKernelActors(const std::vector<KernelActorPtr> &actors, std::ofstream &
void DumpSuperKernelActors(const std::vector<SuperKernelActorPtr> &actors, std::ofstream &ofs);
void DumpNoInputKernelActors(const std::vector<AbstractActorPtr> &actors, std::ofstream &ofs);
void DumpCopyActors(const std::vector<CopyActorPtr> &actors, std::ofstream &ofs);
void DumpGatherActors(const std::vector<GatherActorPtr> &actors, std::ofstream &ofs);
void DumpSwitchActors(const std::vector<SwitchActorPtr> &actors, std::ofstream &ofs);
void DumpControlActors(const ControlActorSetPtr &control_actor_set, std::ofstream &ofs);
} // namespace runtime
} // namespace mindspore

View File

@ -75,6 +75,10 @@ void ExitActor::SendOutput(OpContext<DeviceTensor> *const context) {
}
void ExitActor::CopyDeviceAddress() {
// If node is not empty, it is the exit of funcgraph, no need to create device address.
if (node_ != nullptr) {
return;
}
std::vector<DeviceTensor *> new_device_tensors;
for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
auto input_device_tensor = input_device_tensors_[i];

View File

@ -28,6 +28,7 @@ void GatherActor::FetchInput(OpContext<DeviceTensor> *const context) {
ControlActor::FetchInput(context);
output_partial_ = input_partials_[0];
MS_EXCEPTION_IF_NULL(output_partial_.first);
// Put other real parameter in partial.
for (const auto &device_tensor : input_device_tensors_) {

View File

@ -459,9 +459,9 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
}
// In control flow, all weight nodes associated with the host weight parameter need to use the same device tensor.
void DataPrepareActor::PrepareDataForControlWeightNode(
const AnfNodePtr &node, const AnfNodePtr &front_node, const TensorPtr &tensor, const DeviceContext *device_context,
const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &host_parameter_to_weights,
void DataPrepareActor::PrepareDataForControlWeightNode(const AnfNodePtr &node, const AnfNodePtr &front_node,
const TensorPtr &tensor, const DeviceContext *device_context,
const HostParameterToWeight &host_parameter_to_weights,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(front_node);
@ -516,12 +516,12 @@ void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeP
if (IsPersistentDeviceTensor(input_node)) {
const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters();
const auto &iter = front_to_backend_parameters.find(input_node);
if (iter == front_to_backend_parameters.end()) {
if (iter == front_to_backend_parameters.end() || iter->second.empty()) {
MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:"
<< AnfAlgo::GetNodeDebugString(input_node);
}
const auto &node_with_context = iter->second;
PrepareDataForControlWeightNode(node_with_context.first, input_node, input_tensor, node_with_context.second,
const auto &node_with_context = iter->second.begin();
PrepareDataForControlWeightNode(node_with_context->first, input_node, input_tensor, node_with_context->second,
control_node_parser->host_parameter_to_weights(), context);
}
}

View File

@ -93,9 +93,9 @@ class DataPrepareActor : public DebugAwareActor {
std::vector<TensorPtr> *const host_tensors,
OpContext<DeviceTensor> *const context);
// In control flow, all weight nodes associated with the host weight parameter need to use the same device tensor.
void PrepareDataForControlWeightNode(
const AnfNodePtr &node, const AnfNodePtr &front_node, const TensorPtr &tensor, const DeviceContext *device_context,
const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &host_parameter_to_weights,
void PrepareDataForControlWeightNode(const AnfNodePtr &node, const AnfNodePtr &front_node, const TensorPtr &tensor,
const DeviceContext *device_context,
const HostParameterToWeight &host_parameter_to_weights,
OpContext<DeviceTensor> *const context);
const GraphCompilerInfo *graph_compiler_info_;

File diff suppressed because it is too large Load Diff

View File

@ -21,6 +21,7 @@
#include <string>
#include <memory>
#include <set>
#include <queue>
#include <map>
#include <utility>
#include <unordered_map>
@ -56,102 +57,55 @@ constexpr size_t kSingleControlNode = 1;
const char kEntranceActorNameSuffix[] = "_EntranceActor";
const char kStackActorNameSuffix[] = "_StackActor";
using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>>;
using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::set<std::pair<AnfNodePtr, DeviceContext *>>>;
using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>;
using FuncGraphToKernelGraph = std::unordered_map<FuncGraphPtr, std::vector<KernelGraphPtr>>;
using FuncGraphToParameter = std::unordered_map<FuncGraphPtr, std::vector<std::vector<AnfNodePtr>>>;
using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>;
using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>>;
using NodeWithDeviceContext = std::set<std::pair<AnfNodePtr, DeviceContext *>>;
using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
using FormalToRealParameter = std::unordered_map<AnfNodePtr, std::set<KernelWithIndex>>;
using RealToFormalParameter = std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>>;
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
using FrontNodeToKernelGraph = std::unordered_map<AnfNodePtr, KernelGraphPtr>;
// Check if the call node is the input of another call node.
bool IsSubCallNode(const AnfNodePtr &node);
// Recursive interface, find the real output of funcgraph called by call node.
AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes);
using FuncGraphCallRelation = std::unordered_map<FuncGraphPtr, std::vector<std::set<FuncGraphPtr>>>;
// Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph,
// it is determined whether it is a weight.
bool HasAbstractRef(const AnfNodePtr &node);
// Recursive interface, get the funcgraph which the node belongs, if the node has a front node, return the funcgraph
// which the front node belongs, if not, find the funcgraph which the input of the node belongs.
FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node);
// Recursive interface, get the number of output nodes of funcgraph called by call node.
size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *call_nodes);
// Get front node by backend node.
AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node);
// Get the front node corresponding to the backend node, if the front node is not a parameter node, return the
// corresponding cnode.
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph);
// Get the funcgraph to which the node belongs.
FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
// Find all funcgraphs that the call node will call.
std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node);
// Get parameters in kernel graph.
std::vector<KernelWithIndex> FetchParameterbyKernelGraph(const KernelGraphPtr &graph);
// ControlNodeParser is used to parse control nodes, and get the edges between nodes.
class ControlNodeParser {
public:
// Parse the control node and put the results of the parsing into member variables.
void Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph);
const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph,
const FuncGraphToKernelGraph &func_graph_to_kernel_graphs);
bool IsInited() { return is_inited_; }
// Check whether there is a call node in the front input nodes of the kernel graph.
bool IsCallInputKernelGraph(const KernelGraphPtr &graph);
// Check whether the data arrow of the kernel actor needs to be connected to the control actor.
// There are two situations:
// 1. In control flow, the parameter input needs to be connected to the entrance actor of the funcgraph.
// 2. In the kernel graph with call node input, the data arrow needs to be connected to the stack actor.
bool IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &node);
const std::vector<AnfNodePtr> &control_node_parameters() const { return control_node_parameters_; }
const FrontToBackendNodeWithContext &front_to_backend_parameters() const { return front_to_backend_parameters_; }
const HostParameterToWeight &host_parameter_to_weights() const { return host_parameter_to_weights_; }
const NodeWithDeviceContext &front_value_nodes() const { return front_value_nodes_; }
// Get the output of funcgraph, usually there is only one output node, In the control flow, there are
// multiple branch outputs, there will be multiple output nodes.
std::vector<AnfNodePtr> FetchAllBranchOutputs(const FuncGraphPtr &func_graph);
// Get all possible input nodes of the output node. When the switch actor is the output, it need to send the node
// which device address belongs, so switch actor need to get all the possible nodes.
std::set<KernelWithIndex> FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output);
// Get the device context corresponding to the value node.
DeviceContext *GetFrontValueNodeDeviceContext(const AnfNodePtr &value_node);
// Get the branch id corresponding to call node.
int GetBranchIDByCallNode(const AnfNodePtr &call_node);
// Get the number of calls to funcgraph
size_t GetCallNumByFuncGraph(const FuncGraphPtr &func_graph);
// Get all possible input nodes of the output node. When the gather actor is the output, it need to send the node
// which device address belongs, so gather actor need to get all the possible nodes.
std::vector<KernelWithIndex> GetBackendInputByParameter(const AnfNodePtr &parameter);
// Check whether there is a call node in the front input nodes of the kernel graph.
bool IsCallInputKernelGraph(const KernelGraphPtr &graph);
// Check whether the kernel actor belongs to the root graph.
// In general, all no output nodes belong to the root funcgraph, and the corresponding switch actor for output should
// be empty. In control flow, the control arrow of the no output node in the sub funcgraph should be sent to the
// output switch actor.
bool IsKernelInRootFuncGraph(const AnfNodePtr &kernel);
// Get the backend node corresponding to the weight node in the subgraph.
AnfNodePtr FetchBackendNodebyWeightNode(const AnfNodePtr &node);
KernelWithIndex GetBackendKernelByFrontKernel(const KernelWithIndex &front_node_with_index) {
return front_to_backend_kernels_[front_node_with_index].first;
}
AnfNodePtr FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node);
KernelWithIndex FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index);
// Fetch all funcgraphs that the call node may call.
const std::set<FuncGraphPtr> &FetchFuncGraphbyCallNode(const AnfNodePtr &control_node);
// Fetch the branch id corresponding to funcgraph.
int FetchBranchIDByCallNode(const AnfNodePtr &call_node);
// Fetch the funcgraph which the kernel belongs.
FuncGraphPtr FetchKernelGraphByFrontNode(const AnfNodePtr &kernel);
// Fetch the backend kernel of front node.
KernelWithIndex FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index);
private:
friend class GraphScheduler;
@ -160,134 +114,120 @@ class ControlNodeParser {
// value nodes will not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for
// them separately during initialization.
// The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to.
void FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts);
// Create branch id for all subgraphs in the control flow.
void CreateBranchIDForFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
// Find all value nodes in the switch recursively.
void FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node, std::vector<AnfNodePtr> *value_nodes);
// Fetch all the relationships between front parameters and backend parameters.The front parameters
void FetchFrontValueNode();
// Create branch id for all call node in the control flow.
void CreateBranchIDForCallNode(const std::vector<AnfNodePtr> &control_nodes);
// Parse all the relationships between front parameters and backend parameters.The front parameters
// include two parts:
// 1. The parameter from kernel graph.
// 2. The parameter from control nodes.
void FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const RealToFormalNode &real_to_formal_front_parameters,
const RealToFormalNode &formal_to_real_front_parameters);
// Get the relationship between the front and backend of the executable kernel in all kernel graphs.
void FetchFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs,
void ParseFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts);
// Get inputs of control node which come from the host actor. These inputs generally come from the partial
// nodes and call nodes of the root funcgraph.
std::vector<AnfNodePtr> FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes,
DeviceContext *device_context);
// Get all the input parameters of funcgraph. The call of funcgraph is realized through the call node,
// and the input of the call node is the input parameter of the corresponding funcgraph.
void FetchFuncGraphToParameter(const std::vector<AnfNodePtr> &control_nodes);
// Get all the front weight parameters related to the weight in the host parameter.
void FetchHostParameterToWeight(const RealToFormalNode &real_to_formal_front_parameters);
// The relationship between front parameters indicates that the parameter is directly used as the input of the
// funcgraph. There are two situations:
// 1. The parameter is used as the input of the call node,
// 2. The parameter is used as the input of the partial and will be input to the funcgraph of the partial in the
// subsequent call node.
void FetchFrontToFrontParameter(const std::vector<AnfNodePtr> &control_nodes,
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter);
// Get the number of calls to all subgraphs in the whole funcgraph.
void FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes);
void ParseFormalToRealParameter(const std::vector<AnfNodePtr> &control_nodes);
// Recursively get all the real parameters corresponding to the formal parameters.
void ParseAllRealParameterByFormalParameter(const AnfNodePtr &formal_parameter,
const FormalToRealParameter &formal_to_real_parameters,
std::set<KernelWithIndex> *total_real_parameters,
std::set<AnfNodePtr> *invalid_real_parameter);
// Parse the device context of the control node. In a heterogeneous scenario, different device contexts need to be
// copied between different device memories. The analysis steps:
// 1. Get the device context of the funcgraph parameter according to the device type of the kernel in the funcgraph.
// 2. Determine the type of device context output by funcgraph according to the call relationship of funcgrpah.
void ParseDeviceContext(const std::vector<AnfNodePtr> &control_nodes,
const std::vector<KernelGraphPtr> &kernel_graphs,
const std::vector<DeviceContext *> &device_contexts,
const FuncGraphToKernelGraph &func_graph_to_kernel_graphs);
void ParseDeviceContextForFuncGraph(const std::vector<AnfNodePtr> &control_nodes,
const std::vector<KernelGraphPtr> &kernel_graphs,
const std::vector<DeviceContext *> &device_contexts,
const FuncGraphToKernelGraph &func_graph_to_kernel_graphs);
void ParseDeviceContextForControlNode(const DeviceContext *default_context);
// In the actor model, when the funcgraph comes to an end temporarily, the exit of the funcgraph needs to notify
// the entrance actor so that it can process next parameters. This is used to obtain the nodes corresponding to all
// actors in the funcgraph that need to send control messages to the entrance.
// These node are control nodes without control node input in the topological sort of the funcgraph.
void ParseFirstControlNodeForFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
// Parse all funcgraphs that call nodes may call.
void ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
// Get the relationship between the front and backend of the executable kernel in all kernel graphs.
void FetchFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts);
void FetchFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> &graphs);
// nodes and call nodes of the root funcgraph.
void FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes);
// Get all the front weight parameters related to the weight in the host parameter.
void FetchHostParameterToWeight();
// Get all the kernel graphs where the input node has a call node.
void FetchCallInputKernelGraph(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts);
// Get the relationship of all real and formal nodes in the whole funcgraph.
void FetchBackendInputNode(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const RealToFormalNode &real_to_formal_front_parameters,
const RealToFormalNode &formal_to_real_front_parameters);
// Get the relationship of all real and formal parameters in the whole funcgraph.
void FetchBackendParameterNode(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const RealToFormalNode &real_to_formal_front_parameters,
const RealToFormalNode &formal_to_real_front_parameters,
FrontToBackendNodeWithContext *front_to_backend_parameters);
// Get all possible input node of real parameter.
void FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter,
const FrontToBackendNodeWithContext &front_to_backend_parameters);
// Recursive interface, get all Backend node by front_output.
void FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, std::set<AnfNodePtr> *call_nodes,
std::set<AnfNodePtr> *switch_nodes, std::set<KernelWithIndex> *results);
// Get the dependency between kernel and call node in auto monad.
void FetchAutoMonadNode(const std::vector<AnfNodePtr> &control_nodes);
// Fetch the formal parameter in root graph by parameters in subgraph.
AnfNodePtr FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node);
// In control flow, funcgraph will be cut into multiple kernel graphs for execution, and this relationship is recorded
// in this map.
FuncGraphToKernelGraph func_graph_to_kernel_graphs_;
// The kernel graph to which the front node belongs after the funcgraph is cut.
FrontNodeToKernelGraph front_node_to_kernel_graph_;
// The front to backend parameters is used to build and link the host data source actor in the control flow scenario.
FrontToBackendNodeWithContext front_to_backend_parameters_;
// The relationship between all real parameters and formal parameters in the entire func_graph.
// In control flow, the control actor will be the output actor. Since the actor needs to send the node to the output
// actor, it is necessary to save all the real parameters corresponding to the formal parameters in the control actor.
// When the control actor receives the device address, it can find the corresponding input node.
std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> formal_to_real_parameters_;
// Relationship between the front and backend of the executable kernel in all kernel graphs.
FrontToBackendKernelWithContext front_to_backend_kernels_;
// The funcgraph to parameters map records the input parameters of funcgraph and is used to initialize
// the input node of gather.
FuncGraphToParameter func_graph_to_parameters_;
// Relationship between formal parameters and real parameters.
FormalToRealParameter formal_to_real_parameters_;
RealToFormalParameter real_to_formal_parameters_;
// The relationship between the valuenode inputs of the call node and the backend parameter
std::map<KernelWithIndex, std::pair<AnfNodePtr, DeviceContext *>> call_node_to_backend_parameters_;
// Branch id of call node.
// Branch id of funcgraph.
// In control flow, funcgraph will be called in multiple places, and the output of funcgraph needs to return to
// different places. Therefore, a branch id is created for each call node. When funcgraph is called, the branch
// id needs to be sent to the entrance actor corresponding to the funcgraph, and then send the branch id to its
// output switch actor.
// different places. Therefore, a branch id is created for each funcgraph. When funcgraph is called, the branch
// id needs to be sent to the gather actor corresponding to the funcgraph, and the gather will send the branch id
// to its output switch actor.
std::unordered_map<AnfNodePtr, int> call_node_to_branch_id_;
std::unordered_map<AnfNodePtr, std::set<FuncGraphPtr>> call_node_to_func_graphs_;
// host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph.
// When initializing the weights, all related weights need to be recorded as the same device tensor.
HostParameterToWeight host_parameter_to_weights_;
std::unordered_map<AnfNodePtr, AnfNodePtr> sub_front_node_to_root_front_node_;
// The front value node saves all value nodes that are not in the kernel graph. These nodes are generally the
// input of the control node.
NodeWithDeviceContext front_value_nodes_;
// The front value node saves all parameters that are not in the kernel graph. These nodes are generally the
// output of subgraph, or the switch condition node.
NodeWithDeviceContext front_parameters_;
// Parameters of control node which come from the host actor.
std::vector<AnfNodePtr> control_node_parameters_;
// The number of calls to func_graph.
std::unordered_map<FuncGraphPtr, size_t> func_graph_to_call_num_;
// In control flow, funcgraph will be divided into multiple kernel graphs. This map records this correspondence.
FuncGraphToKernelGraph func_graph_to_kernel_graphs_;
// In control flow, if there is a call node in the funcgraph, it means that when the funcgraph executes to the call,
// it needs to jump to another funcgraph. At this time, the funcgraph needs to process other real parameters, so
// these nodes need to send control arrows to the entrance actor to tell it to continue processing other parameters,
// these nodes are recorded in this map.
std::unordered_map<FuncGraphPtr, std::set<AnfNodePtr>> func_graph_to_first_control_nodes_;
// The kernel graph of call exists in the front input node.
// In the scene of funcgraph recursive call, general input and call input are passed recursively, so a gather actor
// is created for kernel graph which has a call input.
std::unordered_map<KernelGraphPtr, DeviceContext *> call_input_kernel_graphs_;
// The dependency between kernel and call node in auto monad.
std::unordered_map<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_;
// Control nodes without a control node input in the topological sorting of funcgraph.
std::unordered_map<FuncGraphPtr, std::set<AnfNodePtr>> func_graph_to_first_control_nodes_;
// In heterogeneous scenario, each parameter has its own device context type, so the device context corresponding
// to the type needs to be parsed in advance so that it can add some copy operation in the scheduler.
// 1. The device context type of the formal parameters of funcgraph.
std::unordered_map<FuncGraphPtr, std::vector<const DeviceContext *>> func_graph_to_device_contexts_;
// 2. The device context type of the control node inputs.
std::unordered_map<AnfNodePtr, std::vector<const DeviceContext *>> control_node_to_device_contexts_;
// Is control flow enabled.
bool is_inited_{false};
// Root funcgraph and its parameters.
FuncGraphPtr root_func_graph_;
std::vector<AnfNodePtr> root_graph_parameters_;
// The dependency between kernel and call node in auto monad.
std::unordered_map<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_;
// Call node will call different funcgraphs according to the input partial node, and this relationship is recorded
// in this map.
std::unordered_map<AnfNodePtr, std::set<FuncGraphPtr>> call_node_to_func_graphs_;
// In heterogeneous scenarios, different formal parameters of funcgraph will have different contexts. In order to
// ensure that there is no copy actor between control actors, the device context type corresponding to each formal
// parameter needs to be derived in the parser and recorded in this map.
std::unordered_map<FuncGraphPtr, std::vector<const DeviceContext *>> func_graph_to_device_contexts_;
std::unordered_map<AnfNodePtr, std::vector<const DeviceContext *>> control_node_to_device_contexts_;
// Record which kernel graph the front node is in.
FrontNodeToKernelGraph front_node_to_kernel_graph_;
bool is_inited_{false};
};
using ControlNodeParserPtr = std::shared_ptr<ControlNodeParser>;

View File

@ -127,7 +127,8 @@ std::vector<GatherActorPtr> ControlNodeScheduler::BuildGatherActor(const GraphCo
// The gather actor corresponding to a call node needs to set the branch id.
if (AnfAlgo::IsCallNode(control_node)) {
gather_actor->output_branch_id_ = graph_compiler_info.control_node_parser_->GetBranchIDByCallNode(control_node);
gather_actor->output_branch_id_ =
graph_compiler_info.control_node_parser_->FetchBranchIDByCallNode(control_node);
}
}
}
@ -404,7 +405,7 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto exit_actor = dynamic_cast<ExitActor *>(actor);
size_t branch_id = parser->GetBranchIDByCallNode(from_node);
size_t branch_id = parser->FetchBranchIDByCallNode(from_node);
LinkDataArrowForExitActor(exit_actor, to_actor, from_node_with_index.second, to_node_with_index.second,
branch_id);
}

View File

@ -92,6 +92,29 @@ std::vector<AbstractActorPtr> CollectActors(const ActorSet *actor_set) {
if (actor_set->output_actor_ != nullptr) {
(void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->output_actor_));
}
if (actor_set->control_actors_ != nullptr) {
const auto &control_actor_set = actor_set->control_actors_;
for (auto &switch_actor : control_actor_set->switch_actors_) {
MS_EXCEPTION_IF_NULL(switch_actor);
(void)actors.emplace_back(static_cast<AbstractActorPtr>(switch_actor));
}
for (auto &gather_actor : control_actor_set->gather_actors_) {
MS_EXCEPTION_IF_NULL(gather_actor);
(void)actors.emplace_back(static_cast<AbstractActorPtr>(gather_actor));
}
for (auto &entrance_actor : control_actor_set->entrance_actors_) {
MS_EXCEPTION_IF_NULL(entrance_actor);
(void)actors.emplace_back(static_cast<AbstractActorPtr>(entrance_actor));
}
for (auto &exit_actor : control_actor_set->exit_actors_) {
MS_EXCEPTION_IF_NULL(exit_actor);
(void)actors.emplace_back(static_cast<AbstractActorPtr>(exit_actor));
}
for (auto &stack_actor : control_actor_set->stack_actors_) {
MS_EXCEPTION_IF_NULL(stack_actor);
(void)actors.emplace_back(static_cast<AbstractActorPtr>(stack_actor));
}
}
return actors;
}
@ -487,6 +510,8 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
(void)host_queue_ds_actor->data_nodes_.emplace_back(input_node);
(void)host_queue_ds_actor->device_contexts_.emplace_back(device_context);
(void)host_queue_ds_actor->data_node_position_map_.emplace(input_node, data_node_position);
// In control flow, need to rely on the front node to find the location of the corresponding real parameter.
(void)host_queue_ds_actor->data_node_position_map_.emplace(front_node, data_node_position);
(void)front_node_position_temp_map.emplace(front_node, data_node_position);
data_node_position++;
}
@ -525,7 +550,7 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
continue;
}
auto backend_iter = front_to_backend_parameter.find(parameter);
if (backend_iter == front_to_backend_parameter.end()) {
if (backend_iter == front_to_backend_parameter.end() || backend_iter->second.empty()) {
MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter);
}
@ -538,15 +563,20 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
(void)data_source_actors.emplace_back(host_queue_ds_actor);
}
const auto &backend_node = backend_iter->second.first;
if (host_queue_ds_actor->data_node_position_map_.find(parameter) !=
host_queue_ds_actor->data_node_position_map_.end()) {
continue;
}
const auto &backend_node = backend_iter->second.begin()->first;
auto iter = find(host_queue_ds_actor->data_nodes_.begin(), host_queue_ds_actor->data_nodes_.end(), backend_node);
if (iter != host_queue_ds_actor->data_nodes_.end()) {
(void)host_queue_ds_actor->data_node_position_map_.emplace(parameter,
iter - host_queue_ds_actor->data_nodes_.begin());
} else {
(void)host_queue_ds_actor->data_node_position_map_.emplace(parameter, host_queue_ds_actor->data_nodes_.size());
(void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.first);
(void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.second);
(void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.begin()->first);
(void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.begin()->second);
}
}
@ -1297,15 +1327,17 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun
(void)no_output_actors.emplace_back(super_actor.get());
}
}
// In control flow scenario, no output actor needs to be connected to the corresponding exit actor, not loop count.
if (!parser->IsInited()) {
for (auto &kernel_actor : actor_set->kernel_actors_) {
// The no output kernel control side in subgraph needs to be connected to the corresponding output switch actor.
if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) &&
parser->IsKernelInRootFuncGraph(kernel_actor->kernel_)) {
MS_EXCEPTION_IF_NULL(kernel_actor->kernel_);
MS_LOG(INFO) << kernel_actor->kernel_->fullname_with_scope() << " is not real used by other nodes.";
if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0)) {
(void)no_output_actors.emplace_back(kernel_actor.get());
}
}
}
for (auto &data_actor : actor_set->data_source_actors_) {
if ((data_actor->output_data_arrows_.size() == 0) && (data_actor->output_control_arrows_.size() == 0)) {
(void)no_output_actors.emplace_back(data_actor.get());
@ -1332,7 +1364,9 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun
void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
const GraphCompilerInfo &graph_compiler_info) {
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep ||
(graph_compiler_info.control_node_parser_ != nullptr && graph_compiler_info.control_node_parser_->IsInited())) {
// In control flow, the exit actor of the root graph sends output data to the output actor.
return;
}
MS_EXCEPTION_IF_NULL(to_actor);
@ -1706,6 +1740,7 @@ void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInf
DumpCopyActors(actor_set->copy_actors_, ofs);
DumpLoopCountActor(actor_set->loop_count_actor_, ofs);
DumpOutputActor(actor_set->output_actor_, ofs);
DumpControlActors(actor_set->control_actors_, ofs);
}
void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const {

View File

@ -379,6 +379,7 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
// Compile root graph.
graph_id_to_device_context_.clear();
func_graph_to_kernel_graph_ids_.clear();
control_nodes_.clear();
auto subgraph_need_compile = CompileGraph(root_graph);
@ -476,6 +477,10 @@ void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_mu
}
graph_id_to_device_context_[graph_id] = device_context;
const auto &func_graph = segment->nodes_[0]->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
func_graph_to_kernel_graph_ids_[func_graph].emplace_back(graph_id);
} else {
// Compile the cut node.
auto cut_node = segment->nodes_[0];
@ -971,8 +976,18 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
(void)name.append("_").append(std::to_string(graph_id_to_context.first));
}
FuncGraphToKernelGraph func_graph_to_kernel_graphs;
for (const auto &func_graph_to_kernel_graph_ids : func_graph_to_kernel_graph_ids_) {
const auto &func_graph = func_graph_to_kernel_graph_ids.first;
for (const auto &graph_id : func_graph_to_kernel_graph_ids.second) {
const auto &kernel_graph = graph_compiler_->Fetch(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
func_graph_to_kernel_graphs[func_graph].emplace_back(kernel_graph);
}
}
auto parser = std::make_shared<ControlNodeParser>();
parser->Parse(control_nodes_, graphs, device_contexts, root_graph);
parser->Parse(control_nodes_, graphs, device_contexts, root_graph, func_graph_to_kernel_graphs);
runtime::KernelMapPosition outputs_order;
size_t outputs_num = 0;

View File

@ -42,6 +42,7 @@ using ActorInfo = runtime::ActorInfo;
using GraphCompiler = runtime::GraphCompiler;
using GraphCompilerInfo = runtime::GraphCompilerInfo;
using ControlNodeParser = runtime::ControlNodeParser;
using FuncGraphToKernelGraph = runtime::FuncGraphToKernelGraph;
using ControlNodeParserPtr = runtime::ControlNodeParserPtr;
using KernelWithIndex = session::KernelWithIndex;
@ -157,6 +158,8 @@ class MindRTBackend : public Backend {
// node segments. Node segments will be compiled into kernelGraphs which are expressed as GraphId and bound to
// the corresponding device_context.
std::map<GraphId, DeviceContext *> graph_id_to_device_context_;
// Funcgraph will be cut into multiple kernel graphs, and the map is used to save the correspondence.
std::map<FuncGraphPtr, std::vector<GraphId>> func_graph_to_kernel_graph_ids_;
std::map<GraphInfo, DeviceContext *> graph_info_to_device_context_;
std::vector<AnfNodePtr> control_nodes_;