!18891 Fix funcgraph recursive.

Merge pull request !18891 from gaoyong10/new_runtime10
This commit is contained in:
i-robot 2021-06-26 08:07:00 +00:00 committed by Gitee
commit 787c87d8a0
5 changed files with 806 additions and 317 deletions

View File

@ -63,70 +63,14 @@ void GatherActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
}
}
void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph,
const std::vector<AnfNodePtr> &origin_parameters_order,
const FrontToBackendNodeWithContext &front_to_backend_parameters,
const FuncGraphToParameter &func_graph_to_parameters,
const std::unordered_map<AnfNodePtr, AnfNodePtr> &front_to_backend_kernel) {
const auto func_iter = func_graph_to_parameters.find(func_graph);
if (func_iter == func_graph_to_parameters.end()) {
return;
}
std::vector<AnfNodePtr> graph_inputs;
void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser) {
for (const auto &input : func_graph->get_inputs()) {
// Monad inputs are not sent to the gather actor.
if (!HasAbstractMonad(input)) {
graph_inputs.emplace_back(input);
}
}
// Collect all backend input nodes for the gather actor. There are two situations:
// 1. The parameter comes from the host data source.
// 2. The input is the output of a kernel actor.
for (const auto parameters : func_iter->second) {
if (parameters.size() != graph_inputs.size()) {
MS_LOG(EXCEPTION) << "Parameters num is invalid, current:" << parameters.size() << " need:" << graph_inputs.size()
<< " func_graph:" << func_iter->first->ToString();
}
for (size_t i = 0; i < parameters.size(); ++i) {
if (parameters[i]->isa<Parameter>()) {
// Input node is a parameter from host data source actor.
std::vector<AnfNodePtr> invalid_inputs;
std::vector<AnfNodePtr> front_inputs =
FetchInputNodeByParameter(parameters[i], origin_parameters_order, &invalid_inputs, func_graph_to_parameters);
for (const auto &front_input : front_inputs) {
const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(front_input, 0);
if (node_with_index.first->isa<Parameter>()) {
const auto &iter = front_to_backend_parameters.find(parameters[i]);
if (iter == front_to_backend_parameters.end()) {
MS_LOG(EXCEPTION) << "Cannot find backend node of node:"
<< AnfAlgo::GetNodeDebugString(node_with_index.first);
}
front_to_backend_parameter_[graph_inputs[i]].push_back({iter->second.first, 0});
} else {
const auto iter = front_to_backend_kernel.find(node_with_index.first);
if (iter == front_to_backend_kernel.end()) {
MS_LOG(EXCEPTION) << "Cannot find actor of front node:"
<< AnfAlgo::GetNodeDebugString(node_with_index.first);
}
front_to_backend_parameter_[graph_inputs[i]].push_back({iter->second, node_with_index.second});
}
}
} else {
// Input node is a cnode.
const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(parameters[i], 0);
const auto iter = front_to_backend_kernel.find(node_with_index.first);
if (iter == front_to_backend_kernel.end()) {
MS_LOG(EXCEPTION) << "Cannot find backend node of node:"
<< AnfAlgo::GetNodeDebugString(node_with_index.first);
}
front_to_backend_parameter_[graph_inputs[i]].push_back({iter->second, node_with_index.second});
}
if (HasAbstractMonad(input) ||
(input->isa<Parameter>() && AnfAlgo::IsParameterWeight(input->cast<ParameterPtr>()))) {
continue;
}
front_to_backend_parameter_[input] = parser->GetBackendInputByParameter(input);
}
}

View File

@ -56,10 +56,7 @@ class GatherActor : public OpActor<DeviceTensor> {
friend class GraphScheduler;
// Collect the inputs of gather actor.
void FetchBackendInputNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &origin_parameters_order,
const FrontToBackendNodeWithContext &front_to_backend_parameters,
const FuncGraphToParameter &func_graph_to_parameters,
const std::unordered_map<AnfNodePtr, AnfNodePtr> &front_to_backend_kernel);
void FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser);
void FetchInputDeviceTensor(OpContext<DeviceTensor> *context);
// Check whether satisfy the condition for launch.
bool CheckLaunchCondition(OpContext<DeviceTensor> *context) const;

File diff suppressed because it is too large Load Diff

View File

@ -47,22 +47,32 @@ using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>
// 2. First input of node is a funcgraph value node.
bool IsCallNode(const AnfNodePtr &node);
// Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph,
// it is determined whether it is a weight.
bool HasAbstractRef(const AnfNodePtr &node);
// Recursive interface: get the funcgraph to which the node belongs. If the node has a front node, return the
// funcgraph to which the front node belongs; if not, find the funcgraph to which the input of the node belongs.
FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node);
// Recursive interface, get the number of output nodes of funcgraph called by call node.
size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *call_nodes);
// Get front node by backend node.
AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node);
// Get the front node corresponding to the backend node, if the front node is not a parameter node, return the
// corresponding cnode.
AnfNodePtr GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph);
// Get the funcgraph to which the node belongs.
FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
// Find all funcgraphs that the call node will call.
std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const CNodePtr &node);
std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node);
// Fetch all backend input nodes by parameter for gather actor.
std::vector<AnfNodePtr> FetchInputNodeByParameter(const AnfNodePtr &parameter,
const std::vector<AnfNodePtr> &host_ds_parameters,
std::vector<AnfNodePtr> *invalid_inputs,
const FuncGraphToParameter &graph_to_real_parameters);
// Recursive interface, get all input of make tuple node.
std::vector<AnfNodePtr> FetchInputsByMakeTuple(const AnfNodePtr &node);
// ControlNodeParser is used to parse control nodes, and get the edges between nodes.
class ControlNodeParser {
@ -77,6 +87,26 @@ class ControlNodeParser {
// multiple branch outputs, there will be multiple output nodes.
std::vector<AnfNodePtr> FetchAllBranchOutputs(const FuncGraphPtr &func_graph);
// Get all possible input nodes of the output node. When the switch actor is the output, it needs to send the node
// to which the device address belongs, so the switch actor needs to get all the possible nodes.
std::vector<KernelWithIndex> FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output);
// Get the device context corresponding to the value node.
DeviceContext *GetFrontValueNodeDeviceContext(const AnfNodePtr &value_node);
// Get the branch id corresponding to funcgraph.
int GetBranchIDByFuncGraph(const FuncGraphPtr &func_graph);
// Get the number of calls to funcgraph
size_t GetCallNumByFuncGraph(const FuncGraphPtr &func_graph);
// Get all possible input nodes of the output node. When the gather actor is the output, it needs to send the node
// to which the device address belongs, so the gather actor needs to get all the possible nodes.
std::vector<KernelWithIndex> GetBackendInputByParameter(const AnfNodePtr &parameter);
// Check whether there is a call node in the front input nodes of the kernel graph.
bool IsCallInputKernelGraph(const KernelGraphPtr &graph);
private:
friend class GraphScheduler;
@ -84,47 +114,30 @@ class ControlNodeParser {
// value nodes will not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for
// them separately during initialization.
// The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to.
void FetchFrontValueNode(const std::vector<KernelGraphPtr> &graphs,
void FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts);
// Create branch id for all subgraphs in the control flow.
void CreateBranchIDForFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
// Find all value nodes in the switch recursively.
void FetchValueNodeInSwitchNode(const AnfNodePtr &switch_node, std::vector<AnfNodePtr> *value_nodes);
void FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node, std::vector<AnfNodePtr> *value_nodes);
// Fetch all the relationships between front parameters and backend parameters. The front parameters
// include two parts:
// 1. The parameter from kernel graph.
// 2. The parameter from control nodes.
void FetchFrontToBackendParameterMap(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const std::vector<AnfNodePtr> &control_nodes);
void FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const std::vector<AnfNodePtr> &control_nodes);
// Get the relationship between the front and backend of the executable kernel in all kernel graphs.
void FetchFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts);
// Get inputs of control node which come from the host actor. These inputs generally come from the partial
// nodes and call nodes of the root funcgraph.
std::vector<AnfNodePtr> FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes);
// Get all the input parameters of funcgraph. The call of funcgraph is realized through the call node,
// and the input of the call node is the input parameter of the corresponding funcgraph.
void FetchFuncGraphToParameterMap(const std::vector<AnfNodePtr> &control_nodes);
void FetchFuncGraphToParameter(const std::vector<AnfNodePtr> &control_nodes);
// Get all the front weight parameters related to the weight in the host parameter.
void FetchHostParameterToWeightMap(const std::vector<AnfNodePtr> &control_nodes);
// Get the pos input of call node to funcgraph.
AnfNodePtr GetCallNodeInputByPos(const AnfNodePtr &call_node, const FuncGraphPtr &func_graph, const size_t pos);
// Find the output of the funcgraph, if the output is a call node, return the output of the funcgraph
// called by the call node.
std::vector<AnfNodePtr> FetchFuncGraphOutput(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *call_nodes);
// Find the corresponding backend parameter for the front_node. If the front_node does not have the corresponding
// backend parameter, then recursively find the backend parameters of other front parameters corresponding to the
// front_node.
std::pair<AnfNodePtr, DeviceContext *> FetchBackendNodeByFrontNode(
const AnfNodePtr &front_node,
const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &front_to_front_parameter,
const std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>> &front_to_backend_parameter,
std::set<AnfNodePtr> *invalid_node);
void FetchHostParameterToWeight(const std::vector<AnfNodePtr> &control_nodes);
// The relationship between front parameters indicates that the parameter is directly used as the input of the
// funcgraph. There are two situations:
// 1. The parameter is used as the input of the call node,
@ -132,12 +145,44 @@ class ControlNodeParser {
// subsequent call node.
void FetchFrontToFrontParameterMap(const std::vector<AnfNodePtr> &control_nodes,
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter);
// Get the number of calls to all subgraphs in the whole funcgraph.
void FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes);
// Get all the kernel graphs where the input node has a call node.
void FetchCallInputKernelGraph(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts);
// Get the relationship of all real and formal parameters in the whole funcgraph.
void FetchBackendInputNode();
// Get all possible input node of real parameter.
void FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter);
// Recursive interface, get all Backend node by front_output.
std::vector<KernelWithIndex> FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output,
std::set<AnfNodePtr> *call_nodes,
std::set<AnfNodePtr> *switch_nodes);
// The front to backend parameters is used to build and link the host data source actor in the control flow scenario.
FrontToBackendNodeWithContext front_to_backend_parameters_;
// The relationship between all real parameters and formal parameters in the entire func_graph.
// In control flow, the control actor will be the output actor. Since the actor needs to send the node to the output
// actor, it is necessary to save all the real parameters corresponding to the formal parameters in the control actor.
// When the control actor receives the device address, it can find the corresponding input node.
std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> formal_to_real_parameters_;
// Relationship between the front and backend of the executable kernel in all kernel graphs.
FrontToBackendNodeWithContext front_to_backend_kernels_;
// The funcgraph to parameters map records the input parameters of funcgraph and is used to initialize
// the input node of gather.
FuncGraphToParameter func_graph_to_parameters_;
// Branch id of funcgraph.
// In control flow, funcgraph will be called in multiple places, and the output of funcgraph needs to return to
// different places. Therefore, a branch id is created for each funcgraph. When funcgraph is called, the branch
// id needs to be sent to the gather actor corresponding to the funcgraph, and the gather will send the branch id
// to its output switch actor.
std::unordered_map<FuncGraphPtr, int> func_graph_to_branch_id_;
// host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph.
// When initializing the weights, all related weights need to be recorded as the same device tensor.
HostParameterToWeight host_parameter_to_weights_;
@ -148,6 +193,13 @@ class ControlNodeParser {
std::vector<AnfNodePtr> front_output_nodes_;
// Parameters of control node which come from the host actor.
std::vector<AnfNodePtr> control_node_parameters_;
// The number of calls to func_graph.
std::unordered_map<FuncGraphPtr, size_t> func_graph_to_call_num_;
// The kernel graphs whose front-end input nodes include a call node.
std::unordered_map<KernelGraphPtr, DeviceContext *> call_input_kernel_graphs_;
// In the scenario of a recursive funcgraph call, general inputs and call inputs are passed recursively, so a gather
// actor is created for each kernel graph which has a call input.
std::vector<AnfNodePtr> root_graph_parameters_;
};
using ControlNodeParserPtr = std::shared_ptr<ControlNodeParser>;

View File

@ -1062,10 +1062,7 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
auto gather_actor =
std::make_shared<GatherActor>(actor_name, parameters, loop_count_actor->GetAID(), output_actor->GetAID());
gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.origin_parameters_order_,
graph_compiler_info.control_node_parser_->front_to_backend_parameters_,
graph_compiler_info.control_node_parser_->func_graph_to_parameters_,
front_to_backend_kernel);
gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
InsertActor(gather_actor.get());
gather_actors.emplace_back(gather_actor);
}