!19137 link auto monad for gather actor.

Merge pull request !19137 from gaoyong10/new_runtime17
This commit is contained in:
i-robot 2021-06-30 18:57:09 +00:00 committed by Gitee
commit 26d1157d00
5 changed files with 73 additions and 13 deletions

View File

@ -1142,8 +1142,15 @@ FuncGraphPtr KernelGraph::GetFuncGraph() {
if (front_backend_anf_map_.empty()) {
return nullptr;
}
const auto &front_node = front_backend_anf_map_.begin()->first;
return front_node->func_graph();
for (const auto &front_backend_anf : front_backend_anf_map_) {
const auto &front_node = front_backend_anf.first;
const auto &func_graph = front_node->func_graph();
if (func_graph != nullptr) {
return func_graph;
}
}
return nullptr;
}
void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtr &backend_graph_output,

View File

@ -396,6 +396,20 @@ std::vector<AnfNodePtr> FetchOutputBySwitchNode(const AnfNodePtr &switch_node, s
return outputs;
}
// Recursive interface, get the real kernel that UpdateState node depends on.
AnfNodePtr FetchSourceNodeByAutoMonad(const AnfNodePtr &node) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
const auto &cnode = node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
if (inputs.size() <= kUpdateStateRealInput) {
MS_LOG(EXCEPTION) << "Invalid updatestate node:" << AnfAlgo::GetNodeDebugString(node);
}
return FetchSourceNodeByAutoMonad(inputs[kUpdateStateRealInput]);
}
return node;
}
} // namespace
// Return true if the node has Ref abstract.
@ -615,6 +629,8 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
FetchCallInputKernelGraph(graphs, device_contexts);
FetchBackendInputNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters);
FetchAutoMonadNode(control_nodes);
}
std::vector<KernelWithIndex> ControlNodeParser::GetBackendInputByParameter(const AnfNodePtr &parameter) {
@ -1370,5 +1386,27 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr>
formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0});
}
}
void ControlNodeParser::FetchAutoMonadNode(const std::vector<AnfNodePtr> &control_nodes) {
for (const auto &control_node : control_nodes) {
const auto &cnode = control_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Invalid control node:" << AnfAlgo::GetNodeDebugString(control_node);
}
if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimUpdateState)) {
const auto &node = FetchSourceNodeByAutoMonad(inputs[i]);
const auto &iter = front_to_backend_kernels_.find(node);
if (iter != front_to_backend_kernels_.end()) {
kernel_to_call_nodes_[iter->second.first] = control_node;
}
}
}
}
}
}
} // namespace runtime
} // namespace mindspore

View File

@ -181,6 +181,8 @@ class ControlNodeParser {
void FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, std::set<AnfNodePtr> *call_nodes,
std::set<AnfNodePtr> *switch_nodes, std::set<KernelWithIndex> *results);
// Get the dependency between kernel and call node in auto monad.
void FetchAutoMonadNode(const std::vector<AnfNodePtr> &control_nodes);
// The front to backend parameters is used to build and link the host data source actor in the control flow scenario.
FrontToBackendNodeWithContext front_to_backend_parameters_;
@ -226,6 +228,9 @@ class ControlNodeParser {
// Root funcgraph and its parameters.
FuncGraphPtr root_func_graph_;
std::vector<AnfNodePtr> root_graph_parameters_;
// The dependency between kernel and call node in auto monad.
std::unordered_map<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_;
};
using ControlNodeParserPtr = std::shared_ptr<ControlNodeParser>;

View File

@ -1953,8 +1953,7 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
LinkBranchArrowForGatherActor(graph_compiler_info, actor_set);
LinkControlArrowForGatherActor(&(actor_set->gather_actors_), &(actor_set->kernel_actors_),
actor_set->loop_count_actor_.get(), graph_compiler_info.graphs_,
LinkControlArrowForGatherActor(&(actor_set->kernel_actors_), graph_compiler_info.graphs_,
graph_compiler_info.control_node_parser_);
LinkControlArrowForSwitchActor(&(actor_set->switch_actors_), actor_set->loop_count_actor_.get(),
@ -2163,14 +2162,9 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_
}
}
void GraphScheduler::LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors,
std::vector<KernelActorPtr> *kernel_actors,
LoopCountActor *to_actor, const std::vector<KernelGraphPtr> &graphs,
void GraphScheduler::LinkControlArrowForGatherActor(std::vector<KernelActorPtr> *kernel_actors,
const std::vector<KernelGraphPtr> &graphs,
const ControlNodeParserPtr &parser) {
if (from_actors == nullptr || to_actor == nullptr) {
return;
}
// Link control arrow to kernel actor.
for (size_t i = 0; i < graphs.size(); ++i) {
const auto &kernel_graph = graphs[i];
@ -2226,6 +2220,22 @@ void GraphScheduler::LinkControlArrowForGatherActor(std::vector<GatherActorPtr>
}
}
}
// Link input auto monad control arrow from kernel actor to gather actor.
const auto &monad_nodes = parser->kernel_to_call_nodes_;
for (const auto node_pair : monad_nodes) {
const auto &kernel_actor_name = node_pair.first->fullname_with_scope();
const auto &gather_actor_name = node_pair.second->DebugString();
auto kernel_op_actor = FetchActor(kernel_actor_name);
auto gather_op_actor = FetchActor(gather_actor_name);
if (kernel_op_actor == nullptr || gather_op_actor == nullptr) {
continue;
}
auto kernel_actor = dynamic_cast<KernelActor *>(kernel_op_actor);
auto gather_actor = dynamic_cast<GatherActor *>(gather_op_actor);
kernel_actor->output_control_arrows_.emplace_back(gather_actor->GetAID());
gather_actor->input_controls_num_++;
}
}
void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *switch_actors,

View File

@ -242,8 +242,8 @@ class GraphScheduler {
const size_t to_index);
void LinkDataArrowForSwitchActor(SwitchActor *from_actor, const size_t from_index, OpActor<DeviceTensor> *to_actor,
const size_t to_index, const size_t branch_index = SIZE_MAX);
void LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors,
std::vector<KernelActorPtr> *kernel_actors, LoopCountActor *to_actor,
void LinkControlArrowForGatherActor(std::vector<KernelActorPtr> *kernel_actors,
const std::vector<KernelGraphPtr> &graphs, const ControlNodeParserPtr &parser);
void LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *switch_actors, LoopCountActor *to_actor,