!19137 link auto monad for gather actor.
Merge pull request !19137 from gaoyong10/new_runtime17
This commit is contained in:
commit
26d1157d00
|
@ -1142,8 +1142,15 @@ FuncGraphPtr KernelGraph::GetFuncGraph() {
|
|||
if (front_backend_anf_map_.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
const auto &front_node = front_backend_anf_map_.begin()->first;
|
||||
return front_node->func_graph();
|
||||
|
||||
for (const auto &front_backend_anf : front_backend_anf_map_) {
|
||||
const auto &front_node = front_backend_anf.first;
|
||||
const auto &func_graph = front_node->func_graph();
|
||||
if (func_graph != nullptr) {
|
||||
return func_graph;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtr &backend_graph_output,
|
||||
|
|
|
@ -396,6 +396,20 @@ std::vector<AnfNodePtr> FetchOutputBySwitchNode(const AnfNodePtr &switch_node, s
|
|||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
// Recursive interface, get the real kernel that UpdateState node depends on.
|
||||
AnfNodePtr FetchSourceNodeByAutoMonad(const AnfNodePtr &node) {
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (inputs.size() <= kUpdateStateRealInput) {
|
||||
MS_LOG(EXCEPTION) << "Invalid updatestate node:" << AnfAlgo::GetNodeDebugString(node);
|
||||
}
|
||||
|
||||
return FetchSourceNodeByAutoMonad(inputs[kUpdateStateRealInput]);
|
||||
}
|
||||
return node;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Return true if the node has Ref abstract.
|
||||
|
@ -615,6 +629,8 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
|
|||
FetchCallInputKernelGraph(graphs, device_contexts);
|
||||
|
||||
FetchBackendInputNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters);
|
||||
|
||||
FetchAutoMonadNode(control_nodes);
|
||||
}
|
||||
|
||||
std::vector<KernelWithIndex> ControlNodeParser::GetBackendInputByParameter(const AnfNodePtr ¶meter) {
|
||||
|
@ -1370,5 +1386,27 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr>
|
|||
formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0});
|
||||
}
|
||||
}
|
||||
|
||||
void ControlNodeParser::FetchAutoMonadNode(const std::vector<AnfNodePtr> &control_nodes) {
|
||||
for (const auto &control_node : control_nodes) {
|
||||
const auto &cnode = control_node->cast<CNodePtr>();
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (inputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid control node:" << AnfAlgo::GetNodeDebugString(control_node);
|
||||
}
|
||||
|
||||
if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
|
||||
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
|
||||
if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimUpdateState)) {
|
||||
const auto &node = FetchSourceNodeByAutoMonad(inputs[i]);
|
||||
const auto &iter = front_to_backend_kernels_.find(node);
|
||||
if (iter != front_to_backend_kernels_.end()) {
|
||||
kernel_to_call_nodes_[iter->second.first] = control_node;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -181,6 +181,8 @@ class ControlNodeParser {
|
|||
void FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, std::set<AnfNodePtr> *call_nodes,
|
||||
std::set<AnfNodePtr> *switch_nodes, std::set<KernelWithIndex> *results);
|
||||
|
||||
// Get the dependency between kernel and call node in auto monad.
|
||||
void FetchAutoMonadNode(const std::vector<AnfNodePtr> &control_nodes);
|
||||
// The front to backend parameters is used to build and link the host data source actor in the control flow scenario.
|
||||
FrontToBackendNodeWithContext front_to_backend_parameters_;
|
||||
|
||||
|
@ -226,6 +228,9 @@ class ControlNodeParser {
|
|||
// Root funcgraph and its parameters.
|
||||
FuncGraphPtr root_func_graph_;
|
||||
std::vector<AnfNodePtr> root_graph_parameters_;
|
||||
|
||||
// The dependency between kernel and call node in auto monad.
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_;
|
||||
};
|
||||
|
||||
using ControlNodeParserPtr = std::shared_ptr<ControlNodeParser>;
|
||||
|
|
|
@ -1953,8 +1953,7 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
|
|||
|
||||
LinkBranchArrowForGatherActor(graph_compiler_info, actor_set);
|
||||
|
||||
LinkControlArrowForGatherActor(&(actor_set->gather_actors_), &(actor_set->kernel_actors_),
|
||||
actor_set->loop_count_actor_.get(), graph_compiler_info.graphs_,
|
||||
LinkControlArrowForGatherActor(&(actor_set->kernel_actors_), graph_compiler_info.graphs_,
|
||||
graph_compiler_info.control_node_parser_);
|
||||
|
||||
LinkControlArrowForSwitchActor(&(actor_set->switch_actors_), actor_set->loop_count_actor_.get(),
|
||||
|
@ -2163,14 +2162,9 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_
|
|||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors,
|
||||
std::vector<KernelActorPtr> *kernel_actors,
|
||||
LoopCountActor *to_actor, const std::vector<KernelGraphPtr> &graphs,
|
||||
void GraphScheduler::LinkControlArrowForGatherActor(std::vector<KernelActorPtr> *kernel_actors,
|
||||
const std::vector<KernelGraphPtr> &graphs,
|
||||
const ControlNodeParserPtr &parser) {
|
||||
if (from_actors == nullptr || to_actor == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Link control arrow to kernel actor.
|
||||
for (size_t i = 0; i < graphs.size(); ++i) {
|
||||
const auto &kernel_graph = graphs[i];
|
||||
|
@ -2226,6 +2220,22 @@ void GraphScheduler::LinkControlArrowForGatherActor(std::vector<GatherActorPtr>
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Link input auto monad control arrow from kernel actor to gather actor.
|
||||
const auto &monad_nodes = parser->kernel_to_call_nodes_;
|
||||
for (const auto node_pair : monad_nodes) {
|
||||
const auto &kernel_actor_name = node_pair.first->fullname_with_scope();
|
||||
const auto &gather_actor_name = node_pair.second->DebugString();
|
||||
auto kernel_op_actor = FetchActor(kernel_actor_name);
|
||||
auto gather_op_actor = FetchActor(gather_actor_name);
|
||||
if (kernel_op_actor == nullptr || gather_op_actor == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto kernel_actor = dynamic_cast<KernelActor *>(kernel_op_actor);
|
||||
auto gather_actor = dynamic_cast<GatherActor *>(gather_op_actor);
|
||||
kernel_actor->output_control_arrows_.emplace_back(gather_actor->GetAID());
|
||||
gather_actor->input_controls_num_++;
|
||||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *switch_actors,
|
||||
|
|
|
@ -242,8 +242,8 @@ class GraphScheduler {
|
|||
const size_t to_index);
|
||||
void LinkDataArrowForSwitchActor(SwitchActor *from_actor, const size_t from_index, OpActor<DeviceTensor> *to_actor,
|
||||
const size_t to_index, const size_t branch_index = SIZE_MAX);
|
||||
void LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors,
|
||||
std::vector<KernelActorPtr> *kernel_actors, LoopCountActor *to_actor,
|
||||
|
||||
void LinkControlArrowForGatherActor(std::vector<KernelActorPtr> *kernel_actors,
|
||||
const std::vector<KernelGraphPtr> &graphs, const ControlNodeParserPtr &parser);
|
||||
|
||||
void LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *switch_actors, LoopCountActor *to_actor,
|
||||
|
|
Loading…
Reference in New Issue