forked from mindspore-Ecosystem/mindspore
delete useless interface in control node parser
This commit is contained in:
parent
dbbe870036
commit
b5a9588d10
|
@ -132,16 +132,6 @@ bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool IsGatherActor(const AnfNodePtr &front_node,
|
||||
const std::unordered_map<std::string, OpActor<DeviceTensor> *> &actor_name_to_actor) {
|
||||
MS_EXCEPTION_IF_NULL(front_node);
|
||||
if (front_node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(front_node->cast<ParameterPtr>())) &&
|
||||
(front_node->func_graph() != nullptr) && (actor_name_to_actor.count(front_node->func_graph()->ToString()) > 0)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(dst_device_tensor);
|
||||
MS_EXCEPTION_IF_NULL(src_device_tensor);
|
||||
|
|
|
@ -133,10 +133,6 @@ bool IsInternalParameter(const AnfNodePtr &node, const KernelGraphPtr &graph);
|
|||
// Judge whether the device tensor of the node is persistent or not.
|
||||
bool IsPersistentDeviceTensor(const AnfNodePtr &node);
|
||||
|
||||
// Judge whether the front node is in a gather actor.
|
||||
bool IsGatherActor(const AnfNodePtr &front_node,
|
||||
const std::unordered_map<std::string, OpActor<DeviceTensor> *> &actor_name_to_actor);
|
||||
|
||||
// Copy data from src_device_tensor to dst_device_tensor.
|
||||
bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);
|
||||
|
||||
|
|
|
@ -177,14 +177,157 @@ void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) {
|
|||
ofs << "\n";
|
||||
}
|
||||
|
||||
void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
|
||||
void DumpControlActor(const ControlActor *actor, std::ofstream &ofs) {
|
||||
const auto &output_data_arrows = actor->output_data_arrows();
|
||||
if (output_data_arrows.size() > 0) {
|
||||
ofs << "\t\t\toutput_data_arrows:" << output_data_arrows.size() << "\n ";
|
||||
for (const auto &data_arrow : output_data_arrows) {
|
||||
MS_EXCEPTION_IF_NULL(data_arrow);
|
||||
ofs << "\t\t\t\tfrom_output_index:" << data_arrow->from_output_index_
|
||||
<< "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
|
||||
<< "\n";
|
||||
}
|
||||
}
|
||||
|
||||
const auto &output_control_arrows = actor->output_control_arrows();
|
||||
if (output_control_arrows.size() > 0) {
|
||||
ofs << "\t\t\toutput_control_arrows:" << output_control_arrows.size() << "\n ";
|
||||
for (const auto &aid : output_control_arrows) {
|
||||
ofs << "\t\t\t\tto_actor_name:" << aid.Name() << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
const auto &output_partial_arrows = actor->output_partial_arrows();
|
||||
if (output_partial_arrows.size() > 0) {
|
||||
ofs << "\t\t\toutput_partial_arrows:" << output_partial_arrows.size() << "\n ";
|
||||
for (const auto &partial_arrow : output_partial_arrows) {
|
||||
MS_EXCEPTION_IF_NULL(partial_arrow);
|
||||
ofs << "\t\t\t\tfrom_output_index:" << partial_arrow->from_output_index_
|
||||
<< "\tto_actor_name:" << partial_arrow->to_op_id_.Name()
|
||||
<< "\tto_input_index:" << partial_arrow->to_input_index_ << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
const auto &output_branch_id_arrows = actor->output_branch_id_arrows();
|
||||
if (output_branch_id_arrows.size() > 0) {
|
||||
ofs << "\t\t\toutput_branch_id_arrows:" << output_branch_id_arrows.size() << "\n ";
|
||||
for (const auto &aid : output_branch_id_arrows) {
|
||||
ofs << "\t\t\t\tto_actor_name:" << aid.Name() << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Dump one switch actor to ofs: its actor name followed by the common control actor output arrows.
// Throws (via MS_EXCEPTION_IF_NULL) when actor is null.
void DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) {
  MS_EXCEPTION_IF_NULL(actor);
  // The name is written exactly once, using the same two-tab "actor_name:" label as the gather, entrance, exit
  // and stack actor dumps (the previous label was misspelled as "tactor_name:").
  ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n';
  DumpControlActor(actor, ofs);
}
|
||||
|
||||
void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n';
|
||||
DumpControlActor(actor, ofs);
|
||||
|
||||
const auto &output_data_with_branch_id_arrows = actor->output_data_with_branch_id_arrows();
|
||||
if (output_data_with_branch_id_arrows.size() > 0) {
|
||||
ofs << "\t\t\toutput_data_with_branch_id_arrows:" << output_data_with_branch_id_arrows.size() << "\n ";
|
||||
for (const auto &output_data_with_branch_id_arrow : output_data_with_branch_id_arrows) {
|
||||
ofs << "\t\t\t\tbranch funcgraph:" << output_data_with_branch_id_arrow.first->ToString() << "\n";
|
||||
for (const auto &arrow : output_data_with_branch_id_arrow.second) {
|
||||
ofs << "\t\t\t\t\tto actor:" << arrow << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Dump one entrance actor to ofs: the actor name line, then the common control actor output arrows.
// Throws (via MS_EXCEPTION_IF_NULL) when actor is null.
void DumpEntranceActor(const EntranceActor *actor, std::ofstream &ofs) {
  MS_EXCEPTION_IF_NULL(actor);
  const auto entrance_name = actor->GetAID().Name();
  ofs << "\t\tactor_name:" << entrance_name << '\n';
  DumpControlActor(actor, ofs);
}
|
||||
|
||||
void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n';
|
||||
DumpControlActor(actor, ofs);
|
||||
|
||||
const auto &output_branch_data_arrows = actor->output_branch_data_arrows();
|
||||
if (output_branch_data_arrows.size() > 0) {
|
||||
ofs << "\t\t\toutput_branch_data_arrows:" << output_branch_data_arrows.size() << "\n ";
|
||||
for (const auto &output_branch_data_arrow : output_branch_data_arrows) {
|
||||
ofs << "\t\t\t\tbranch id:" << output_branch_data_arrow.first << "\n";
|
||||
for (const auto &arrow : output_branch_data_arrow.second) {
|
||||
MS_EXCEPTION_IF_NULL(arrow);
|
||||
ofs << "\t\t\t\t\tfrom_output_index:" << arrow->from_output_index_
|
||||
<< "\tto_actor_name:" << arrow->to_op_id_.Name() << "\tto_input_index:" << arrow->to_input_index_ << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto &output_branch_partial_arrows = actor->output_branch_partial_arrows();
|
||||
if (output_branch_partial_arrows.size() > 0) {
|
||||
ofs << "\t\t\toutput_branch_partial_arrows:" << output_branch_partial_arrows.size() << "\n ";
|
||||
for (const auto &output_branch_partial_arrow : output_branch_partial_arrows) {
|
||||
ofs << "\t\t\t\tbranch id:" << output_branch_partial_arrow.first << "\n";
|
||||
for (const auto &arrow : output_branch_partial_arrow.second) {
|
||||
MS_EXCEPTION_IF_NULL(arrow);
|
||||
ofs << "\t\t\t\t\tfrom_output_index:" << arrow->from_output_index_
|
||||
<< "\tto_actor_name:" << arrow->to_op_id_.Name() << "\tto_input_index:" << arrow->to_input_index_ << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto &output_branch_control_arrows = actor->output_branch_control_arrows();
|
||||
if (output_branch_control_arrows.size() > 0) {
|
||||
ofs << "\t\t\toutput_branch_control_arrows:" << output_branch_control_arrows.size() << "\n ";
|
||||
for (const auto &output_branch_control_arrow : output_branch_control_arrows) {
|
||||
ofs << "\t\t\t\tbranch id:" << output_branch_control_arrow.first << "\n";
|
||||
for (const auto &arrow : output_branch_control_arrow.second) {
|
||||
ofs << "\t\t\t\t\tto actor:" << arrow << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Dump one stack actor to ofs: the actor name line, then the common control actor output arrows.
// Throws (via MS_EXCEPTION_IF_NULL) when actor is null.
void DumpStackActor(const StackActor *actor, std::ofstream &ofs) {
  MS_EXCEPTION_IF_NULL(actor);
  const auto stack_name = actor->GetAID().Name();
  ofs << "\t\tactor_name:" << stack_name << '\n';
  DumpControlActor(actor, ofs);
}
|
||||
|
||||
void DumpSwitchActors(const std::vector<SwitchActorPtr> &actors, std::ofstream &ofs) {
|
||||
ofs << "\n\n\t[Switch actors:" << actors.size() << "]\n";
|
||||
for (const auto &switch_actor : actors) {
|
||||
DumpSwitchActor(switch_actor.get(), ofs);
|
||||
}
|
||||
}
|
||||
|
||||
void DumpGatherActors(const std::vector<GatherActorPtr> &actors, std::ofstream &ofs) {
|
||||
ofs << "\n\n\t[Gather actors:" << actors.size() << "]\n";
|
||||
for (const auto &gather_actor : actors) {
|
||||
DumpGatherActor(gather_actor.get(), ofs);
|
||||
}
|
||||
}
|
||||
|
||||
void DumpEntranceActors(const std::vector<EntranceActorPtr> &actors, std::ofstream &ofs) {
|
||||
ofs << "\n\n\t[Entrance actors:" << actors.size() << "]\n";
|
||||
for (const auto &entrance_actor : actors) {
|
||||
DumpEntranceActor(entrance_actor.get(), ofs);
|
||||
}
|
||||
}
|
||||
|
||||
void DumpExitActors(const std::vector<ExitActorPtr> &actors, std::ofstream &ofs) {
|
||||
ofs << "\n\n\t[Exit actors:" << actors.size() << "]\n";
|
||||
for (const auto &exit_actor : actors) {
|
||||
DumpExitActor(exit_actor.get(), ofs);
|
||||
}
|
||||
}
|
||||
|
||||
void DumpStackActors(const std::vector<StackActorPtr> &actors, std::ofstream &ofs) {
|
||||
ofs << "\n\n\t[Stack actors:" << actors.size() << "]\n";
|
||||
for (const auto &stack_actor : actors) {
|
||||
DumpStackActor(stack_actor.get(), ofs);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -281,18 +424,17 @@ void DumpCopyActors(const std::vector<CopyActorPtr> &actors, std::ofstream &ofs)
|
|||
}
|
||||
}
|
||||
|
||||
void DumpGatherActors(const std::vector<GatherActorPtr> &actors, std::ofstream &ofs) {
|
||||
ofs << "\n\n[Gather actors:" << actors.size() << "]\n";
|
||||
for (const auto &gather_actor : actors) {
|
||||
DumpGatherActor(gather_actor.get(), ofs);
|
||||
void DumpControlActors(const ControlActorSetPtr &control_actor_set, std::ofstream &ofs) {
|
||||
ofs << "\n\n[Control actors]\n";
|
||||
if (control_actor_set == nullptr) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void DumpSwitchActors(const std::vector<SwitchActorPtr> &actors, std::ofstream &ofs) {
|
||||
ofs << "\n\n[Switch actors:" << actors.size() << "]\n";
|
||||
for (const auto &switch_actor : actors) {
|
||||
DumpSwitchActor(switch_actor.get(), ofs);
|
||||
}
|
||||
DumpSwitchActors(control_actor_set->switch_actors_, ofs);
|
||||
DumpGatherActors(control_actor_set->gather_actors_, ofs);
|
||||
DumpEntranceActors(control_actor_set->entrance_actors_, ofs);
|
||||
DumpExitActors(control_actor_set->exit_actors_, ofs);
|
||||
DumpStackActors(control_actor_set->stack_actors_, ofs);
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,8 +30,13 @@
|
|||
#include "runtime/framework/actor/super_kernel_actor.h"
|
||||
#include "runtime/framework/actor/output_actor.h"
|
||||
#include "runtime/framework/actor/copy_actor.h"
|
||||
#include "runtime/framework/actor/control_flow/control_actor.h"
|
||||
#include "runtime/framework/actor/control_flow/switch_actor.h"
|
||||
#include "runtime/framework/actor/control_flow/gather_actor.h"
|
||||
#include "runtime/framework/actor/control_flow/entrance_actor.h"
|
||||
#include "runtime/framework/actor/control_flow/exit_actor.h"
|
||||
#include "runtime/framework/actor/control_flow/stack_actor.h"
|
||||
#include "runtime/framework/control_node_scheduler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
|
@ -43,8 +48,7 @@ void DumpKernelActors(const std::vector<KernelActorPtr> &actors, std::ofstream &
|
|||
void DumpSuperKernelActors(const std::vector<SuperKernelActorPtr> &actors, std::ofstream &ofs);
|
||||
void DumpNoInputKernelActors(const std::vector<AbstractActorPtr> &actors, std::ofstream &ofs);
|
||||
void DumpCopyActors(const std::vector<CopyActorPtr> &actors, std::ofstream &ofs);
|
||||
void DumpGatherActors(const std::vector<GatherActorPtr> &actors, std::ofstream &ofs);
|
||||
void DumpSwitchActors(const std::vector<SwitchActorPtr> &actors, std::ofstream &ofs);
|
||||
void DumpControlActors(const ControlActorSetPtr &control_actor_set, std::ofstream &ofs);
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -75,6 +75,10 @@ void ExitActor::SendOutput(OpContext<DeviceTensor> *const context) {
|
|||
}
|
||||
|
||||
void ExitActor::CopyDeviceAddress() {
|
||||
// If node is not empty, it is the exit of funcgraph, no need to create device address.
|
||||
if (node_ != nullptr) {
|
||||
return;
|
||||
}
|
||||
std::vector<DeviceTensor *> new_device_tensors;
|
||||
for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
|
||||
auto input_device_tensor = input_device_tensors_[i];
|
||||
|
|
|
@ -28,6 +28,7 @@ void GatherActor::FetchInput(OpContext<DeviceTensor> *const context) {
|
|||
|
||||
ControlActor::FetchInput(context);
|
||||
output_partial_ = input_partials_[0];
|
||||
MS_EXCEPTION_IF_NULL(output_partial_.first);
|
||||
|
||||
// Put other real parameter in partial.
|
||||
for (const auto &device_tensor : input_device_tensors_) {
|
||||
|
|
|
@ -459,10 +459,10 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
|
|||
}
|
||||
|
||||
// In control flow, all weight nodes associated with the host weight parameter need to use the same device tensor.
|
||||
void DataPrepareActor::PrepareDataForControlWeightNode(
|
||||
const AnfNodePtr &node, const AnfNodePtr &front_node, const TensorPtr &tensor, const DeviceContext *device_context,
|
||||
const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &host_parameter_to_weights,
|
||||
OpContext<DeviceTensor> *const context) {
|
||||
void DataPrepareActor::PrepareDataForControlWeightNode(const AnfNodePtr &node, const AnfNodePtr &front_node,
|
||||
const TensorPtr &tensor, const DeviceContext *device_context,
|
||||
const HostParameterToWeight &host_parameter_to_weights,
|
||||
OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(front_node);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
|
@ -516,12 +516,12 @@ void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeP
|
|||
if (IsPersistentDeviceTensor(input_node)) {
|
||||
const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters();
|
||||
const auto &iter = front_to_backend_parameters.find(input_node);
|
||||
if (iter == front_to_backend_parameters.end()) {
|
||||
if (iter == front_to_backend_parameters.end() || iter->second.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:"
|
||||
<< AnfAlgo::GetNodeDebugString(input_node);
|
||||
}
|
||||
const auto &node_with_context = iter->second;
|
||||
PrepareDataForControlWeightNode(node_with_context.first, input_node, input_tensor, node_with_context.second,
|
||||
const auto &node_with_context = iter->second.begin();
|
||||
PrepareDataForControlWeightNode(node_with_context->first, input_node, input_tensor, node_with_context->second,
|
||||
control_node_parser->host_parameter_to_weights(), context);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -93,10 +93,10 @@ class DataPrepareActor : public DebugAwareActor {
|
|||
std::vector<TensorPtr> *const host_tensors,
|
||||
OpContext<DeviceTensor> *const context);
|
||||
// In control flow, all weight nodes associated with the host weight parameter need to use the same device tensor.
|
||||
void PrepareDataForControlWeightNode(
|
||||
const AnfNodePtr &node, const AnfNodePtr &front_node, const TensorPtr &tensor, const DeviceContext *device_context,
|
||||
const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &host_parameter_to_weights,
|
||||
OpContext<DeviceTensor> *const context);
|
||||
void PrepareDataForControlWeightNode(const AnfNodePtr &node, const AnfNodePtr &front_node, const TensorPtr &tensor,
|
||||
const DeviceContext *device_context,
|
||||
const HostParameterToWeight &host_parameter_to_weights,
|
||||
OpContext<DeviceTensor> *const context);
|
||||
|
||||
const GraphCompilerInfo *graph_compiler_info_;
|
||||
GraphExecutionStrategy strategy_;
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -21,6 +21,7 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <queue>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
|
@ -56,102 +57,55 @@ constexpr size_t kSingleControlNode = 1;
|
|||
const char kEntranceActorNameSuffix[] = "_EntranceActor";
|
||||
const char kStackActorNameSuffix[] = "_StackActor";
|
||||
|
||||
using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>>;
|
||||
using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::set<std::pair<AnfNodePtr, DeviceContext *>>>;
|
||||
using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>;
|
||||
using FuncGraphToKernelGraph = std::unordered_map<FuncGraphPtr, std::vector<KernelGraphPtr>>;
|
||||
using FuncGraphToParameter = std::unordered_map<FuncGraphPtr, std::vector<std::vector<AnfNodePtr>>>;
|
||||
using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
|
||||
using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>;
|
||||
using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>>;
|
||||
using NodeWithDeviceContext = std::set<std::pair<AnfNodePtr, DeviceContext *>>;
|
||||
using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
|
||||
using FormalToRealParameter = std::unordered_map<AnfNodePtr, std::set<KernelWithIndex>>;
|
||||
using RealToFormalParameter = std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>>;
|
||||
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
||||
using FrontNodeToKernelGraph = std::unordered_map<AnfNodePtr, KernelGraphPtr>;
|
||||
|
||||
// Check if the call node is the input of another call node.
|
||||
bool IsSubCallNode(const AnfNodePtr &node);
|
||||
|
||||
// Recursive interface, find the real output of funcgraph called by call node.
|
||||
AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes);
|
||||
using FuncGraphCallRelation = std::unordered_map<FuncGraphPtr, std::vector<std::set<FuncGraphPtr>>>;
|
||||
|
||||
// Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph,
|
||||
// it is determined whether it is a weight.
|
||||
bool HasAbstractRef(const AnfNodePtr &node);
|
||||
|
||||
// Recursive interface, get the funcgraph which the node belongs, if the node has a front node, return the funcgraph
|
||||
// which the front node belongs, if not, find the funcgraph which the input of the node belongs.
|
||||
FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node);
|
||||
|
||||
// Recursive interface, get the number of output nodes of funcgraph called by call node.
|
||||
size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *call_nodes);
|
||||
|
||||
// Get front node by backend node.
|
||||
AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node);
|
||||
|
||||
// Get the front node corresponding to the backend node, if the front node is not a parameter node, return the
|
||||
// corresponding cnode.
|
||||
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph);
|
||||
|
||||
// Get the funcgraph to which the node belongs.
|
||||
FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
|
||||
|
||||
// Find all funcgraphs that the call node will call.
|
||||
std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node);
|
||||
|
||||
// Get parameters in kernel graph.
|
||||
std::vector<KernelWithIndex> FetchParameterbyKernelGraph(const KernelGraphPtr &graph);
|
||||
|
||||
// ControlNodeParser is used to parse control nodes, and get the edges between nodes.
|
||||
class ControlNodeParser {
|
||||
public:
|
||||
// Parse the control node and put the results of the parsing into member variables.
|
||||
void Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph);
|
||||
const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph,
|
||||
const FuncGraphToKernelGraph &func_graph_to_kernel_graphs);
|
||||
|
||||
bool IsInited() { return is_inited_; }
|
||||
// Check whether there is a call node in the front input nodes of the kernel graph.
|
||||
bool IsCallInputKernelGraph(const KernelGraphPtr &graph);
|
||||
// Check whether the data arrow of the kernel actor needs to be connected to the control actor.
|
||||
// There are two situations:
|
||||
// 1. In control flow, the parameter input needs to be connected to the entrance actor of the funcgraph.
|
||||
// 2. In the kernel graph with call node input, the data arrow needs to be connected to the stack actor.
|
||||
bool IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &node);
|
||||
|
||||
const std::vector<AnfNodePtr> &control_node_parameters() const { return control_node_parameters_; }
|
||||
const FrontToBackendNodeWithContext &front_to_backend_parameters() const { return front_to_backend_parameters_; }
|
||||
const HostParameterToWeight &host_parameter_to_weights() const { return host_parameter_to_weights_; }
|
||||
const NodeWithDeviceContext &front_value_nodes() const { return front_value_nodes_; }
|
||||
|
||||
// Get the output of funcgraph, usually there is only one output node, In the control flow, there are
|
||||
// multiple branch outputs, there will be multiple output nodes.
|
||||
std::vector<AnfNodePtr> FetchAllBranchOutputs(const FuncGraphPtr &func_graph);
|
||||
|
||||
// Get all possible input nodes of the output node. When the switch actor is the output, it needs to send the node
|
||||
// to which the device address belongs, so the switch actor needs to get all the possible nodes.
|
||||
std::set<KernelWithIndex> FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output);
|
||||
|
||||
// Get the device context corresponding to the value node.
|
||||
DeviceContext *GetFrontValueNodeDeviceContext(const AnfNodePtr &value_node);
|
||||
|
||||
// Get the branch id corresponding to call node.
|
||||
int GetBranchIDByCallNode(const AnfNodePtr &call_node);
|
||||
|
||||
// Get the number of calls to funcgraph
|
||||
size_t GetCallNumByFuncGraph(const FuncGraphPtr &func_graph);
|
||||
|
||||
// Get all possible input nodes of the output node. When the gather actor is the output, it need to send the node
|
||||
// which device address belongs, so gather actor need to get all the possible nodes.
|
||||
std::vector<KernelWithIndex> GetBackendInputByParameter(const AnfNodePtr ¶meter);
|
||||
|
||||
// Check whether there is a call node in the front input nodes of the kernel graph.
|
||||
bool IsCallInputKernelGraph(const KernelGraphPtr &graph);
|
||||
|
||||
// Check whether the kernel actor belongs to the root graph.
|
||||
// In general, all no output nodes belong to the root funcgraph, and the corresponding switch actor for output should
|
||||
// be empty. In control flow, the control arrow of the no output node in the sub funcgraph should be sent to the
|
||||
// output switch actor.
|
||||
bool IsKernelInRootFuncGraph(const AnfNodePtr &kernel);
|
||||
|
||||
// Get the backend node corresponding to the weight node in the subgraph.
|
||||
AnfNodePtr FetchBackendNodebyWeightNode(const AnfNodePtr &node);
|
||||
|
||||
// Look up the backend kernel (node with output index) recorded for the given front kernel.
// The lookup uses find() so that a miss does not insert a default entry into front_to_backend_kernels_.
// On a miss a value-initialized KernelWithIndex is returned, which is the same value the former operator[]
// lookup produced.
KernelWithIndex GetBackendKernelByFrontKernel(const KernelWithIndex &front_node_with_index) {
  const auto iter = front_to_backend_kernels_.find(front_node_with_index);
  if (iter == front_to_backend_kernels_.end()) {
    return KernelWithIndex();
  }
  return iter->second.first;
}
|
||||
|
||||
AnfNodePtr FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node);
|
||||
KernelWithIndex FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index);
|
||||
// Fetch all funcgraphs that the call node may call.
|
||||
const std::set<FuncGraphPtr> &FetchFuncGraphbyCallNode(const AnfNodePtr &control_node);
|
||||
// Fetch the branch id corresponding to funcgraph.
|
||||
int FetchBranchIDByCallNode(const AnfNodePtr &call_node);
|
||||
// Fetch the funcgraph which the kernel belongs.
|
||||
FuncGraphPtr FetchKernelGraphByFrontNode(const AnfNodePtr &kernel);
|
||||
// Fetch the backend kernel of front node.
|
||||
KernelWithIndex FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index);
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
@ -160,134 +114,120 @@ class ControlNodeParser {
|
|||
// value nodes will not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for
|
||||
// them separately during initialization.
|
||||
// The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to.
|
||||
void FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts);
|
||||
// Create branch id for all subgraphs in the control flow.
|
||||
void CreateBranchIDForFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
|
||||
// Find all value nodes in the switch recursively.
|
||||
void FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node, std::vector<AnfNodePtr> *value_nodes);
|
||||
// Fetch all the relationships between front parameters and backend parameters. The front parameters
|
||||
void FetchFrontValueNode();
|
||||
// Create branch id for all call node in the control flow.
|
||||
void CreateBranchIDForCallNode(const std::vector<AnfNodePtr> &control_nodes);
|
||||
|
||||
// Parse all the relationships between front parameters and backend parameters. The front parameters
|
||||
// include two parts:
|
||||
// 1. The parameter from kernel graph.
|
||||
// 2. The parameter from control nodes.
|
||||
void FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts,
|
||||
const RealToFormalNode &real_to_formal_front_parameters,
|
||||
const RealToFormalNode &formal_to_real_front_parameters);
|
||||
// Get the relationship between the front and backend of the executable kernel in all kernel graphs.
|
||||
void FetchFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts);
|
||||
// Get inputs of control node which come from the host actor. These inputs generally come from the partial
|
||||
// nodes and call nodes of the root funcgraph.
|
||||
std::vector<AnfNodePtr> FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes,
|
||||
DeviceContext *device_context);
|
||||
// Get all the input parameters of funcgraph. The call of funcgraph is realized through the call node,
|
||||
// and the input of the call node is the input parameter of the corresponding funcgraph.
|
||||
void FetchFuncGraphToParameter(const std::vector<AnfNodePtr> &control_nodes);
|
||||
// Get all the front weight parameters related to the weight in the host parameter.
|
||||
void FetchHostParameterToWeight(const RealToFormalNode &real_to_formal_front_parameters);
|
||||
void ParseFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts);
|
||||
// The relationship between front parameters indicates that the parameter is directly used as the input of the
|
||||
// funcgraph. There are two situations:
|
||||
// 1. The parameter is used as the input of the call node,
|
||||
// 2. The parameter is used as the input of the partial and will be input to the funcgraph of the partial in the
|
||||
// subsequent call node.
|
||||
void FetchFrontToFrontParameter(const std::vector<AnfNodePtr> &control_nodes,
|
||||
std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter);
|
||||
// Get the number of calls to all subgraphs in the whole funcgraph.
|
||||
void FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes);
|
||||
void ParseFormalToRealParameter(const std::vector<AnfNodePtr> &control_nodes);
|
||||
// Recursively get all the real parameters corresponding to the formal parameters.
|
||||
void ParseAllRealParameterByFormalParameter(const AnfNodePtr &formal_parameter,
|
||||
const FormalToRealParameter &formal_to_real_parameters,
|
||||
std::set<KernelWithIndex> *total_real_parameters,
|
||||
std::set<AnfNodePtr> *invalid_real_parameter);
|
||||
|
||||
// Parse the device context of the control node. In a heterogeneous scenario, different device contexts need to be
|
||||
// copied between different device memories. The analysis steps:
|
||||
// 1. Get the device context of the funcgraph parameter according to the device type of the kernel in the funcgraph.
|
||||
// 2. Determine the type of device context output by funcgraph according to the call relationship of funcgraph.
|
||||
void ParseDeviceContext(const std::vector<AnfNodePtr> &control_nodes,
|
||||
const std::vector<KernelGraphPtr> &kernel_graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts,
|
||||
const FuncGraphToKernelGraph &func_graph_to_kernel_graphs);
|
||||
void ParseDeviceContextForFuncGraph(const std::vector<AnfNodePtr> &control_nodes,
|
||||
const std::vector<KernelGraphPtr> &kernel_graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts,
|
||||
const FuncGraphToKernelGraph &func_graph_to_kernel_graphs);
|
||||
void ParseDeviceContextForControlNode(const DeviceContext *default_context);
|
||||
|
||||
// In the actor model, when the funcgraph comes to an end temporarily, the exit of the funcgraph needs to notify
|
||||
// the entrance actor so that it can process next parameters. This is used to obtain the nodes corresponding to all
|
||||
// actors in the funcgraph that need to send control messages to the entrance.
|
||||
// These node are control nodes without control node input in the topological sort of the funcgraph.
|
||||
void ParseFirstControlNodeForFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
|
||||
// Parse all funcgraphs that call nodes may call.
|
||||
void ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
|
||||
|
||||
// Get the relationship between the front and backend of the executable kernel in all kernel graphs.
|
||||
void FetchFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts);
|
||||
void FetchFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> &graphs);
|
||||
// nodes and call nodes of the root funcgraph.
|
||||
void FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes);
|
||||
// Get all the front weight parameters related to the weight in the host parameter.
|
||||
void FetchHostParameterToWeight();
|
||||
// Get all the kernel graphs where the input node has a call node.
|
||||
void FetchCallInputKernelGraph(const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts);
|
||||
// Get the relationship of all real and formal nodes in the whole funcgraph.
|
||||
void FetchBackendInputNode(const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts,
|
||||
const RealToFormalNode &real_to_formal_front_parameters,
|
||||
const RealToFormalNode &formal_to_real_front_parameters);
|
||||
// Get the relationship of all real and formal parameters in the whole funcgraph.
|
||||
void FetchBackendParameterNode(const std::vector<KernelGraphPtr> &graphs,
|
||||
const std::vector<DeviceContext *> &device_contexts,
|
||||
const RealToFormalNode &real_to_formal_front_parameters,
|
||||
const RealToFormalNode &formal_to_real_front_parameters,
|
||||
FrontToBackendNodeWithContext *front_to_backend_parameters);
|
||||
// Get all possible input node of real parameter.
|
||||
void FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter,
|
||||
const FrontToBackendNodeWithContext &front_to_backend_parameters);
|
||||
// Recursive interface, get all Backend node by front_output.
|
||||
void FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, std::set<AnfNodePtr> *call_nodes,
|
||||
std::set<AnfNodePtr> *switch_nodes, std::set<KernelWithIndex> *results);
|
||||
|
||||
// Get the dependency between kernel and call node in auto monad.
|
||||
void FetchAutoMonadNode(const std::vector<AnfNodePtr> &control_nodes);
|
||||
// Fetch the formal parameter in root graph by parameters in subgraph.
|
||||
AnfNodePtr FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node);
|
||||
|
||||
// In control flow, funcgraph will be cut into multiple kernel graphs for execution, and this relationship is recorded
|
||||
// in this map.
|
||||
FuncGraphToKernelGraph func_graph_to_kernel_graphs_;
|
||||
// The kernel graph to which the front node belongs after the funcgraph is cut.
|
||||
FrontNodeToKernelGraph front_node_to_kernel_graph_;
|
||||
|
||||
// The front to backend parameters is used to build and link the host data source actor in the control flow scenario.
|
||||
FrontToBackendNodeWithContext front_to_backend_parameters_;
|
||||
|
||||
// The relationship between all real parameters and formal parameters in the entire func_graph.
|
||||
// In control flow, the control actor will be the output actor. Since the actor needs to send the node to the output
|
||||
// actor, it is necessary to save all the real parameters corresponding to the formal parameters in the control actor.
|
||||
// When the control actor receives the device address, it can find the corresponding input node.
|
||||
std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> formal_to_real_parameters_;
|
||||
|
||||
// Relationship between the front and backend of the executable kernel in all kernel graphs.
|
||||
FrontToBackendKernelWithContext front_to_backend_kernels_;
|
||||
|
||||
// The funcgraph to parameters map records the input parameters of funcgraph and is used to initialize
|
||||
// the input node of gather.
|
||||
FuncGraphToParameter func_graph_to_parameters_;
|
||||
// Relationship between formal parameters and real parameters.
|
||||
FormalToRealParameter formal_to_real_parameters_;
|
||||
RealToFormalParameter real_to_formal_parameters_;
|
||||
|
||||
// The relationship between the valuenode inputs of the call node and the backend parameter
|
||||
std::map<KernelWithIndex, std::pair<AnfNodePtr, DeviceContext *>> call_node_to_backend_parameters_;
|
||||
|
||||
// Branch id of call node.
|
||||
// Branch id of funcgraph.
|
||||
// In control flow, funcgraph will be called in multiple places, and the output of funcgraph needs to return to
|
||||
// different places. Therefore, a branch id is created for each call node. When funcgraph is called, the branch
|
||||
// id needs to be sent to the entrance actor corresponding to the funcgraph, and then send the branch id to its
|
||||
// output switch actor.
|
||||
// different places. Therefore, a branch id is created for each funcgraph. When funcgraph is called, the branch
|
||||
// id needs to be sent to the gather actor corresponding to the funcgraph, and the gather will send the branch id
|
||||
// to its output switch actor.
|
||||
std::unordered_map<AnfNodePtr, int> call_node_to_branch_id_;
|
||||
|
||||
std::unordered_map<AnfNodePtr, std::set<FuncGraphPtr>> call_node_to_func_graphs_;
|
||||
// host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph.
|
||||
// When initializing the weights, all related weights need to be recorded as the same device tensor.
|
||||
HostParameterToWeight host_parameter_to_weights_;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> sub_front_node_to_root_front_node_;
|
||||
|
||||
// The front value node saves all value nodes that are not in the kernel graph. These nodes are generally the
|
||||
// input of the control node.
|
||||
NodeWithDeviceContext front_value_nodes_;
|
||||
// The front parameters save all parameters that are not in the kernel graph. These nodes are generally the
|
||||
// output of subgraph, or the switch condition node.
|
||||
NodeWithDeviceContext front_parameters_;
|
||||
|
||||
// Parameters of control node which come from the host actor.
|
||||
std::vector<AnfNodePtr> control_node_parameters_;
|
||||
// The number of calls to func_graph.
|
||||
std::unordered_map<FuncGraphPtr, size_t> func_graph_to_call_num_;
|
||||
// In control flow, funcgraph will be divided into multiple kernel graphs. This map records this correspondence.
|
||||
FuncGraphToKernelGraph func_graph_to_kernel_graphs_;
|
||||
// In control flow, if there is a call node in the funcgraph, it means that when the funcgraph executes to the call,
|
||||
// it needs to jump to another funcgraph. At this time, the funcgraph needs to process other real parameters, so
|
||||
// these nodes need to send control arrows to the entrance actor to tell it to continue processing other parameters,
|
||||
// these nodes are recorded in this map.
|
||||
std::unordered_map<FuncGraphPtr, std::set<AnfNodePtr>> func_graph_to_first_control_nodes_;
|
||||
// The kernel graphs whose front input nodes include a call node.
|
||||
// In the scenario of funcgraph recursive call, general input and call input are passed recursively, so a gather actor
|
||||
// is created for kernel graph which has a call input.
|
||||
std::unordered_map<KernelGraphPtr, DeviceContext *> call_input_kernel_graphs_;
|
||||
// The dependency between kernel and call node in auto monad.
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_;
|
||||
// Control nodes without a control node input in the topological sorting of funcgraph.
|
||||
std::unordered_map<FuncGraphPtr, std::set<AnfNodePtr>> func_graph_to_first_control_nodes_;
|
||||
|
||||
// In heterogeneous scenario, each parameter has its own device context type, so the device context corresponding
|
||||
// to the type needs to be parsed in advance so that it can add some copy operation in the scheduler.
|
||||
// 1. The device context type of the formal parameters of funcgraph.
|
||||
std::unordered_map<FuncGraphPtr, std::vector<const DeviceContext *>> func_graph_to_device_contexts_;
|
||||
// 2. The device context type of the control node inputs.
|
||||
std::unordered_map<AnfNodePtr, std::vector<const DeviceContext *>> control_node_to_device_contexts_;
|
||||
|
||||
// Whether control flow is enabled.
|
||||
bool is_inited_{false};
|
||||
|
||||
// Root funcgraph and its parameters.
|
||||
FuncGraphPtr root_func_graph_;
|
||||
std::vector<AnfNodePtr> root_graph_parameters_;
|
||||
|
||||
// The dependency between kernel and call node in auto monad.
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_;
|
||||
// Call node will call different funcgraphs according to the input partial node, and this relationship is recorded
|
||||
// in this map.
|
||||
std::unordered_map<AnfNodePtr, std::set<FuncGraphPtr>> call_node_to_func_graphs_;
|
||||
// In heterogeneous scenarios, different formal parameters of funcgraph will have different contexts. In order to
|
||||
// ensure that there is no copy actor between control actors, the device context type corresponding to each formal
|
||||
// parameter needs to be derived in the parser and recorded in this map.
|
||||
std::unordered_map<FuncGraphPtr, std::vector<const DeviceContext *>> func_graph_to_device_contexts_;
|
||||
std::unordered_map<AnfNodePtr, std::vector<const DeviceContext *>> control_node_to_device_contexts_;
|
||||
// Record which kernel graph the front node is in.
|
||||
FrontNodeToKernelGraph front_node_to_kernel_graph_;
|
||||
bool is_inited_{false};
|
||||
};
|
||||
|
||||
using ControlNodeParserPtr = std::shared_ptr<ControlNodeParser>;
|
||||
|
|
|
@ -127,7 +127,8 @@ std::vector<GatherActorPtr> ControlNodeScheduler::BuildGatherActor(const GraphCo
|
|||
|
||||
// The gather actor corresponding to a call node needs to set the branch id.
|
||||
if (AnfAlgo::IsCallNode(control_node)) {
|
||||
gather_actor->output_branch_id_ = graph_compiler_info.control_node_parser_->GetBranchIDByCallNode(control_node);
|
||||
gather_actor->output_branch_id_ =
|
||||
graph_compiler_info.control_node_parser_->FetchBranchIDByCallNode(control_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -404,7 +405,7 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont
|
|||
auto actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto exit_actor = dynamic_cast<ExitActor *>(actor);
|
||||
size_t branch_id = parser->GetBranchIDByCallNode(from_node);
|
||||
size_t branch_id = parser->FetchBranchIDByCallNode(from_node);
|
||||
LinkDataArrowForExitActor(exit_actor, to_actor, from_node_with_index.second, to_node_with_index.second,
|
||||
branch_id);
|
||||
}
|
||||
|
|
|
@ -92,6 +92,29 @@ std::vector<AbstractActorPtr> CollectActors(const ActorSet *actor_set) {
|
|||
if (actor_set->output_actor_ != nullptr) {
|
||||
(void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->output_actor_));
|
||||
}
|
||||
if (actor_set->control_actors_ != nullptr) {
|
||||
const auto &control_actor_set = actor_set->control_actors_;
|
||||
for (auto &switch_actor : control_actor_set->switch_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(switch_actor);
|
||||
(void)actors.emplace_back(static_cast<AbstractActorPtr>(switch_actor));
|
||||
}
|
||||
for (auto &gather_actor : control_actor_set->gather_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(gather_actor);
|
||||
(void)actors.emplace_back(static_cast<AbstractActorPtr>(gather_actor));
|
||||
}
|
||||
for (auto &entrance_actor : control_actor_set->entrance_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(entrance_actor);
|
||||
(void)actors.emplace_back(static_cast<AbstractActorPtr>(entrance_actor));
|
||||
}
|
||||
for (auto &exit_actor : control_actor_set->exit_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(exit_actor);
|
||||
(void)actors.emplace_back(static_cast<AbstractActorPtr>(exit_actor));
|
||||
}
|
||||
for (auto &stack_actor : control_actor_set->stack_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(stack_actor);
|
||||
(void)actors.emplace_back(static_cast<AbstractActorPtr>(stack_actor));
|
||||
}
|
||||
}
|
||||
|
||||
return actors;
|
||||
}
|
||||
|
@ -487,6 +510,8 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
|
|||
(void)host_queue_ds_actor->data_nodes_.emplace_back(input_node);
|
||||
(void)host_queue_ds_actor->device_contexts_.emplace_back(device_context);
|
||||
(void)host_queue_ds_actor->data_node_position_map_.emplace(input_node, data_node_position);
|
||||
// In control flow, need to rely on the front node to find the location of the corresponding real parameter.
|
||||
(void)host_queue_ds_actor->data_node_position_map_.emplace(front_node, data_node_position);
|
||||
(void)front_node_position_temp_map.emplace(front_node, data_node_position);
|
||||
data_node_position++;
|
||||
}
|
||||
|
@ -525,7 +550,7 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
|
|||
continue;
|
||||
}
|
||||
auto backend_iter = front_to_backend_parameter.find(parameter);
|
||||
if (backend_iter == front_to_backend_parameter.end()) {
|
||||
if (backend_iter == front_to_backend_parameter.end() || backend_iter->second.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter);
|
||||
}
|
||||
|
||||
|
@ -538,15 +563,20 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
|
|||
(void)data_source_actors.emplace_back(host_queue_ds_actor);
|
||||
}
|
||||
|
||||
const auto &backend_node = backend_iter->second.first;
|
||||
if (host_queue_ds_actor->data_node_position_map_.find(parameter) !=
|
||||
host_queue_ds_actor->data_node_position_map_.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &backend_node = backend_iter->second.begin()->first;
|
||||
auto iter = find(host_queue_ds_actor->data_nodes_.begin(), host_queue_ds_actor->data_nodes_.end(), backend_node);
|
||||
if (iter != host_queue_ds_actor->data_nodes_.end()) {
|
||||
(void)host_queue_ds_actor->data_node_position_map_.emplace(parameter,
|
||||
iter - host_queue_ds_actor->data_nodes_.begin());
|
||||
} else {
|
||||
(void)host_queue_ds_actor->data_node_position_map_.emplace(parameter, host_queue_ds_actor->data_nodes_.size());
|
||||
(void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.first);
|
||||
(void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.second);
|
||||
(void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.begin()->first);
|
||||
(void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.begin()->second);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1297,15 +1327,17 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun
|
|||
(void)no_output_actors.emplace_back(super_actor.get());
|
||||
}
|
||||
}
|
||||
for (auto &kernel_actor : actor_set->kernel_actors_) {
|
||||
// The no output kernel control side in subgraph needs to be connected to the corresponding output switch actor.
|
||||
if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) &&
|
||||
parser->IsKernelInRootFuncGraph(kernel_actor->kernel_)) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_actor->kernel_);
|
||||
MS_LOG(INFO) << kernel_actor->kernel_->fullname_with_scope() << " is not real used by other nodes.";
|
||||
(void)no_output_actors.emplace_back(kernel_actor.get());
|
||||
|
||||
// In control flow scenario, no output actor needs to be connected to the corresponding exit actor, not loop count.
|
||||
if (!parser->IsInited()) {
|
||||
for (auto &kernel_actor : actor_set->kernel_actors_) {
|
||||
// The no output kernel control side in subgraph needs to be connected to the corresponding output switch actor.
|
||||
if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0)) {
|
||||
(void)no_output_actors.emplace_back(kernel_actor.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &data_actor : actor_set->data_source_actors_) {
|
||||
if ((data_actor->output_data_arrows_.size() == 0) && (data_actor->output_control_arrows_.size() == 0)) {
|
||||
(void)no_output_actors.emplace_back(data_actor.get());
|
||||
|
@ -1332,7 +1364,9 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun
|
|||
|
||||
void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
|
||||
const GraphCompilerInfo &graph_compiler_info) {
|
||||
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
|
||||
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep ||
|
||||
(graph_compiler_info.control_node_parser_ != nullptr && graph_compiler_info.control_node_parser_->IsInited())) {
|
||||
// In control flow, the exit actor of the root graph sends output data to the output actor.
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
|
@ -1706,6 +1740,7 @@ void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInf
|
|||
DumpCopyActors(actor_set->copy_actors_, ofs);
|
||||
DumpLoopCountActor(actor_set->loop_count_actor_, ofs);
|
||||
DumpOutputActor(actor_set->output_actor_, ofs);
|
||||
DumpControlActors(actor_set->control_actors_, ofs);
|
||||
}
|
||||
|
||||
void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const {
|
||||
|
|
|
@ -379,6 +379,7 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
|
|||
|
||||
// Compile root graph.
|
||||
graph_id_to_device_context_.clear();
|
||||
func_graph_to_kernel_graph_ids_.clear();
|
||||
control_nodes_.clear();
|
||||
auto subgraph_need_compile = CompileGraph(root_graph);
|
||||
|
||||
|
@ -476,6 +477,10 @@ void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_mu
|
|||
}
|
||||
|
||||
graph_id_to_device_context_[graph_id] = device_context;
|
||||
|
||||
const auto &func_graph = segment->nodes_[0]->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
func_graph_to_kernel_graph_ids_[func_graph].emplace_back(graph_id);
|
||||
} else {
|
||||
// Compile the cut node.
|
||||
auto cut_node = segment->nodes_[0];
|
||||
|
@ -971,8 +976,18 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
|
|||
(void)name.append("_").append(std::to_string(graph_id_to_context.first));
|
||||
}
|
||||
|
||||
FuncGraphToKernelGraph func_graph_to_kernel_graphs;
|
||||
for (const auto &func_graph_to_kernel_graph_ids : func_graph_to_kernel_graph_ids_) {
|
||||
const auto &func_graph = func_graph_to_kernel_graph_ids.first;
|
||||
for (const auto &graph_id : func_graph_to_kernel_graph_ids.second) {
|
||||
const auto &kernel_graph = graph_compiler_->Fetch(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
func_graph_to_kernel_graphs[func_graph].emplace_back(kernel_graph);
|
||||
}
|
||||
}
|
||||
|
||||
auto parser = std::make_shared<ControlNodeParser>();
|
||||
parser->Parse(control_nodes_, graphs, device_contexts, root_graph);
|
||||
parser->Parse(control_nodes_, graphs, device_contexts, root_graph, func_graph_to_kernel_graphs);
|
||||
|
||||
runtime::KernelMapPosition outputs_order;
|
||||
size_t outputs_num = 0;
|
||||
|
|
|
@ -42,6 +42,7 @@ using ActorInfo = runtime::ActorInfo;
|
|||
using GraphCompiler = runtime::GraphCompiler;
|
||||
using GraphCompilerInfo = runtime::GraphCompilerInfo;
|
||||
using ControlNodeParser = runtime::ControlNodeParser;
|
||||
using FuncGraphToKernelGraph = runtime::FuncGraphToKernelGraph;
|
||||
using ControlNodeParserPtr = runtime::ControlNodeParserPtr;
|
||||
using KernelWithIndex = session::KernelWithIndex;
|
||||
|
||||
|
@ -157,6 +158,8 @@ class MindRTBackend : public Backend {
|
|||
// node segments. Node segments will be compiled into kernelGraphs which are expressed as GraphId and bound to
|
||||
// the corresponding device_context.
|
||||
std::map<GraphId, DeviceContext *> graph_id_to_device_context_;
|
||||
// Funcgraph will be cut into multiple kernel graphs, and the map is used to save the correspondence.
|
||||
std::map<FuncGraphPtr, std::vector<GraphId>> func_graph_to_kernel_graph_ids_;
|
||||
std::map<GraphInfo, DeviceContext *> graph_info_to_device_context_;
|
||||
std::vector<AnfNodePtr> control_nodes_;
|
||||
|
||||
|
|
Loading…
Reference in New Issue