!18891 Fix funcgraph recursive.
Merge pull request !18891 from gaoyong10/new_runtime10
This commit is contained in:
commit
787c87d8a0
|
@ -63,70 +63,14 @@ void GatherActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
|
|||
}
|
||||
}
|
||||
|
||||
void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph,
|
||||
const std::vector<AnfNodePtr> &origin_parameters_order,
|
||||
const FrontToBackendNodeWithContext &front_to_backend_parameters,
|
||||
const FuncGraphToParameter &func_graph_to_parameters,
|
||||
const std::unordered_map<AnfNodePtr, AnfNodePtr> &front_to_backend_kernel) {
|
||||
const auto func_iter = func_graph_to_parameters.find(func_graph);
|
||||
if (func_iter == func_graph_to_parameters.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> graph_inputs;
|
||||
void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser) {
|
||||
for (const auto &input : func_graph->get_inputs()) {
|
||||
// Monad input would not send to gather actor.
|
||||
if (!HasAbstractMonad(input)) {
|
||||
graph_inputs.emplace_back(input);
|
||||
}
|
||||
}
|
||||
|
||||
// Collect all backend input node to gather, There are two situations:
|
||||
// 1. The parameter from the host data source.
|
||||
// 2. Output the kernel actor.
|
||||
for (const auto parameters : func_iter->second) {
|
||||
if (parameters.size() != graph_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Parameters num is invalid, current:" << parameters.size() << " need:" << graph_inputs.size()
|
||||
<< " func_graph:" << func_iter->first->ToString();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < parameters.size(); ++i) {
|
||||
if (parameters[i]->isa<Parameter>()) {
|
||||
// Input node is a parameter from host data source actor.
|
||||
std::vector<AnfNodePtr> invalid_inputs;
|
||||
std::vector<AnfNodePtr> front_inputs =
|
||||
FetchInputNodeByParameter(parameters[i], origin_parameters_order, &invalid_inputs, func_graph_to_parameters);
|
||||
|
||||
for (const auto &front_input : front_inputs) {
|
||||
const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(front_input, 0);
|
||||
|
||||
if (node_with_index.first->isa<Parameter>()) {
|
||||
const auto &iter = front_to_backend_parameters.find(parameters[i]);
|
||||
if (iter == front_to_backend_parameters.end()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find backend node of node:"
|
||||
<< AnfAlgo::GetNodeDebugString(node_with_index.first);
|
||||
}
|
||||
front_to_backend_parameter_[graph_inputs[i]].push_back({iter->second.first, 0});
|
||||
} else {
|
||||
const auto iter = front_to_backend_kernel.find(node_with_index.first);
|
||||
if (iter == front_to_backend_kernel.end()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find actor of front node:"
|
||||
<< AnfAlgo::GetNodeDebugString(node_with_index.first);
|
||||
}
|
||||
front_to_backend_parameter_[graph_inputs[i]].push_back({iter->second, node_with_index.second});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Input node is a cnode.
|
||||
const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(parameters[i], 0);
|
||||
const auto iter = front_to_backend_kernel.find(node_with_index.first);
|
||||
if (iter == front_to_backend_kernel.end()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find backend node of node:"
|
||||
<< AnfAlgo::GetNodeDebugString(node_with_index.first);
|
||||
}
|
||||
front_to_backend_parameter_[graph_inputs[i]].push_back({iter->second, node_with_index.second});
|
||||
}
|
||||
if (HasAbstractMonad(input) ||
|
||||
(input->isa<Parameter>() && AnfAlgo::IsParameterWeight(input->cast<ParameterPtr>()))) {
|
||||
continue;
|
||||
}
|
||||
front_to_backend_parameter_[input] = parser->GetBackendInputByParameter(input);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -56,10 +56,7 @@ class GatherActor : public OpActor<DeviceTensor> {
|
|||
friend class GraphScheduler;
|
||||
|
||||
// Collect the inputs of gather actor.
|
||||
void FetchBackendInputNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &origin_parameters_order,
|
||||
const FrontToBackendNodeWithContext &front_to_backend_parameters,
|
||||
const FuncGraphToParameter &func_graph_to_parameters,
|
||||
const std::unordered_map<AnfNodePtr, AnfNodePtr> &front_to_backend_kernel);
|
||||
void FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser);
|
||||
void FetchInputDeviceTensor(OpContext<DeviceTensor> *context);
|
||||
// Check whether satisfy the condition for launch.
|
||||
bool CheckLaunchCondition(OpContext<DeviceTensor> *context) const;
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -47,22 +47,32 @@ using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>
|
|||
// 2. First input of node is a funcgraph value node.
|
||||
bool IsCallNode(const AnfNodePtr &node);
|
||||
|
||||
// Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph,
|
||||
// it is determined whether it is a weight.
|
||||
bool HasAbstractRef(const AnfNodePtr &node);
|
||||
|
||||
// Recursive interface, get the funcgraph which the node belongs, if the node has a front node, return the funcgraph
|
||||
// which the front node belongs, if not, find the funcgraph which the input of the node belongs.
|
||||
FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node);
|
||||
|
||||
// Recursive interface, get the number of output nodes of funcgraph called by call node.
|
||||
size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *call_nodes);
|
||||
|
||||
// Get front node by backend node.
|
||||
AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node);
|
||||
|
||||
// Get the front node corresponding to the backend node, if the front node is not a parameter node, return the
|
||||
// corresponding cnode.
|
||||
AnfNodePtr GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph);
|
||||
|
||||
// Get the funcgraph to which the node belongs.
|
||||
FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
|
||||
|
||||
// Find all funcgraphs that the call node will call.
|
||||
std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const CNodePtr &node);
|
||||
std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node);
|
||||
|
||||
// Fetch all backend input nodes by parameter for gather actor.
|
||||
std::vector<AnfNodePtr> FetchInputNodeByParameter(const AnfNodePtr ¶meter,
|
||||
const std::vector<AnfNodePtr> &host_ds_parameters,
|
||||
std::vector<AnfNodePtr> *invalid_inputs,
|
||||
const FuncGraphToParameter &graph_to_real_parameters);
|
||||
// Recursive interface, get all input of make tuple node.
|
||||
std::vector<AnfNodePtr> FetchInputsByMakeTuple(const AnfNodePtr &node);
|
||||
|
||||
// ControlNodeParser is used to parse control nodes, and get the edges between nodes.
|
||||
class ControlNodeParser {
|
||||
|
@ -77,6 +87,26 @@ class ControlNodeParser {
|
|||
// multiple branch outputs, there will be multiple output nodes.
|
||||
std::vector<AnfNodePtr> FetchAllBranchOutputs(const FuncGraphPtr &func_graph);
|
||||
|
||||
// Get all possible input nodes of the output node. When the switch actor is the output, it need to send the node
|
||||
// which device address belongs, so switch actor need to get all the possible nodes.
|
||||
std::vector<KernelWithIndex> FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output);
|
||||
|
||||
// Get the device context corresponding to the value node.
|
||||
DeviceContext *GetFrontValueNodeDeviceContext(const AnfNodePtr &value_node);
|
||||
|
||||
// Get the branch id corresponding to funcgraph.
|
||||
int GetBranchIDByFuncGraph(const FuncGraphPtr &func_graph);
|
||||
|
||||
// Get the number of calls to funcgraph
|
||||
size_t GetCallNumByFuncGraph(const FuncGraphPtr &func_graph);
|
||||
|
||||
// Get all possible input nodes of the output node. When the gather actor is the output, it need to send the node
|
||||
// which device address belongs, so gather actor need to get all the possible nodes.
|
||||
std::vector<KernelWithIndex> GetBackendInputByParameter(const AnfNodePtr ¶meter);
|
||||
|
||||
// Check whether there is a call node in the front input nodes of the kernel graph.
|
||||
bool IsCallInputKernelGraph(const KernelGraphPtr &graph);
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
||||
|
@ -84,47 +114,30 @@ class ControlNodeParser {
|
|||
// value nodes will not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for
|
||||
// them separately during initialization.
|
||||
// The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to.
|
||||
void FetchFrontValueNode(const std::vector<KernelGraphPtr> &graphs,
|
||||
void FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts);
|
||||
|
||||
// Create branch id for all subgraphs in the control flow.
|
||||
void CreateBranchIDForFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
|
||||
// Find all value nodes in the switch recursively.
|
||||
void FetchValueNodeInSwitchNode(const AnfNodePtr &switch_node, std::vector<AnfNodePtr> *value_nodes);
|
||||
|
||||
void FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node, std::vector<AnfNodePtr> *value_nodes);
|
||||
// Fetch all the relationships between front parameters and backend parameters.The front parameters
|
||||
// include two parts:
|
||||
// 1. The parameter from kernel graph.
|
||||
// 2. The parameter from control nodes.
|
||||
void FetchFrontToBackendParameterMap(const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts,
|
||||
const std::vector<AnfNodePtr> &control_nodes);
|
||||
|
||||
void FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts,
|
||||
const std::vector<AnfNodePtr> &control_nodes);
|
||||
// Get the relationship between the front and backend of the executable kernel in all kernel graphs.
|
||||
void FetchFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts);
|
||||
// Get inputs of control node which come from the host actor. These inputs generally come from the partial
|
||||
// nodes and call nodes of the root funcgraph.
|
||||
std::vector<AnfNodePtr> FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes);
|
||||
|
||||
// Get all the input parameters of funcgraph. The call of funcgraph is realized through the call node,
|
||||
// and the input of the call node is the input parameter of the corresponding funcgraph.
|
||||
void FetchFuncGraphToParameterMap(const std::vector<AnfNodePtr> &control_nodes);
|
||||
|
||||
void FetchFuncGraphToParameter(const std::vector<AnfNodePtr> &control_nodes);
|
||||
// Get all the front weight parameters related to the weight in the host parameter.
|
||||
void FetchHostParameterToWeightMap(const std::vector<AnfNodePtr> &control_nodes);
|
||||
|
||||
// Get the pos input of call node to funcgraph.
|
||||
AnfNodePtr GetCallNodeInputByPos(const AnfNodePtr &call_node, const FuncGraphPtr &func_graph, const size_t pos);
|
||||
|
||||
// Find the output of the funcgraph, if the output is a call node, return the output of the funcgraph
|
||||
// called by the call node.
|
||||
std::vector<AnfNodePtr> FetchFuncGraphOutput(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *call_nodes);
|
||||
|
||||
// Find the corresponding backend parameter for the front_node. If the front_node does not have the corresponding
|
||||
// backend parameter, then recursively find the backend parameters of other front parameters corresponding to the
|
||||
// front_node.
|
||||
std::pair<AnfNodePtr, DeviceContext *> FetchBackendNodeByFrontNode(
|
||||
const AnfNodePtr &front_node,
|
||||
const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &front_to_front_parameter,
|
||||
const std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>> &front_to_backend_parameter,
|
||||
std::set<AnfNodePtr> *invalid_node);
|
||||
|
||||
void FetchHostParameterToWeight(const std::vector<AnfNodePtr> &control_nodes);
|
||||
// The relationship between front parameters indicates that the parameter is directly used as the input of the
|
||||
// funcgraph. There are two situations:
|
||||
// 1. The parameter is used as the input of the call node,
|
||||
|
@ -132,12 +145,44 @@ class ControlNodeParser {
|
|||
// subsequent call node.
|
||||
void FetchFrontToFrontParameterMap(const std::vector<AnfNodePtr> &control_nodes,
|
||||
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter);
|
||||
// Get the number of calls to all subgraphs in the whole funcgraph.
|
||||
void FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes);
|
||||
// Get all the kernel graphs where the input node has a call node.
|
||||
void FetchCallInputKernelGraph(const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts);
|
||||
// Get the relationship of all real and formal parameters in the whole funcgraph.
|
||||
void FetchBackendInputNode();
|
||||
|
||||
// Get all possible input node of real parameter.
|
||||
void FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter);
|
||||
// Recursive interface, get all Backend node by front_output.
|
||||
std::vector<KernelWithIndex> FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output,
|
||||
std::set<AnfNodePtr> *call_nodes,
|
||||
std::set<AnfNodePtr> *switch_nodes);
|
||||
|
||||
// The front to backend parameters is used to build and link the host data source actor in the control flow scenario.
|
||||
FrontToBackendNodeWithContext front_to_backend_parameters_;
|
||||
|
||||
// The relationship between all real parameters and formal parameters in the entire func_graph.
|
||||
// In control flow, the control actor will be the output actor. Since the actor needs to send the node to the output
|
||||
// actor, it is necessary to save all the real parameters corresponding to the formal parameters in the control actor.
|
||||
// When the control actor receives the device address, it can find the corresponding input node.
|
||||
std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> formal_to_real_parameters_;
|
||||
|
||||
// Relationship between the front and backend of the executable kernel in all kernel graphs.
|
||||
FrontToBackendNodeWithContext front_to_backend_kernels_;
|
||||
|
||||
// The funcgraph to parameters map records the input parameters of funcgraph and is used to initialize
|
||||
// the input node of gather.
|
||||
FuncGraphToParameter func_graph_to_parameters_;
|
||||
|
||||
// Branch id of funcgraph.
|
||||
// In control flow, funcgraph will be called in multiple places, and the output of funcgraph needs to return to
|
||||
// different places. Therefore, a branch id is created for each funcgraph. When funcgraph is called, the branch
|
||||
// id needs to be sent to the gather actor corresponding to the funcgraph, and the gather will send the branch id
|
||||
// to its output switch actor.
|
||||
std::unordered_map<FuncGraphPtr, int> func_graph_to_branch_id_;
|
||||
|
||||
// host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph.
|
||||
// When initializing the weights, all related weights need to be recorded as the same device tensor.
|
||||
HostParameterToWeight host_parameter_to_weights_;
|
||||
|
@ -148,6 +193,13 @@ class ControlNodeParser {
|
|||
std::vector<AnfNodePtr> front_output_nodes_;
|
||||
// Parameters of control node which come from the host actor.
|
||||
std::vector<AnfNodePtr> control_node_parameters_;
|
||||
// The number of calls to func_graph.
|
||||
std::unordered_map<FuncGraphPtr, size_t> func_graph_to_call_num_;
|
||||
// The kernel graph of call exists in the front-end input node.
|
||||
std::unordered_map<KernelGraphPtr, DeviceContext *> call_input_kernel_graphs_;
|
||||
// In the scene of funcgrarph recursive call, general input and call input are passed recursively, so a gather actor
|
||||
// is created for kernel graph which has a call input.
|
||||
std::vector<AnfNodePtr> root_graph_parameters_;
|
||||
};
|
||||
|
||||
using ControlNodeParserPtr = std::shared_ptr<ControlNodeParser>;
|
||||
|
|
|
@ -1062,10 +1062,7 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
|
|||
|
||||
auto gather_actor =
|
||||
std::make_shared<GatherActor>(actor_name, parameters, loop_count_actor->GetAID(), output_actor->GetAID());
|
||||
gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.origin_parameters_order_,
|
||||
graph_compiler_info.control_node_parser_->front_to_backend_parameters_,
|
||||
graph_compiler_info.control_node_parser_->func_graph_to_parameters_,
|
||||
front_to_backend_kernel);
|
||||
gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
|
||||
InsertActor(gather_actor.get());
|
||||
gather_actors.emplace_back(gather_actor);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue