!18895 Add dump interface and link control arrow for switch actor.

Merge pull request !18895 from gaoyong10/new_runtime11
This commit is contained in:
i-robot 2021-06-26 08:08:58 +00:00 committed by Gitee
commit d1fe65892e
3 changed files with 188 additions and 1 deletions

View File

@ -21,6 +21,7 @@
#include <string>
#include <memory>
#include <utility>
#include <stack>
#include <unordered_map>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/device_tensor_store.h"
@ -103,18 +104,35 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
// The position of the branch output in the input_nodes_.
std::vector<std::vector<size_t>> branch_inputs_pos_;
std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;
std::unordered_map<uuids::uuid *, std::unordered_map<AID *, size_t>> input_controls_;
// Branch ids is used to record the id corresponding to the switch output branch.
// In control flow, sub funcgraph may be called in multiple places, and the output must be return to different
// places. Therefore, the output of each subgraph will be connected to a switch actor, and the caller will send
// its branch id to the gather of the subgraph. Then branch id will be sent by the gather actor to the switch
// actor connected to the output.
// In a recursive scenario, the switch will sequentially receive the branch ids sent by the caller, and the switch
// actor needs to store the branch ids in the stack, and pop up in turn when returning.
std::unordered_map<uuids::uuid *, std::stack<int>> input_branch_ids_;
// Control arrows of different branches.
std::vector<std::vector<AID>> output_branch_control_arrows_;
// Branch id arrows of different branches.
std::vector<std::vector<AID>> output_branch_branch_arrows_;
// Result arrows of different branches.
std::vector<std::vector<DataArrowPtr>> output_branch_result_arrows_;
// When the output is a value node from switch actor, the actor needs to send the anfnode to the output actor,
// so all the nodes that may send the device tensor to switch actor are recorded.
std::vector<std::vector<KernelWithIndex>> front_to_backend_parameter_;
std::vector<std::vector<KernelWithIndex>> backend_parameters_;
std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;
std::vector<FuncGraphPtr> branch_func_graph_;
std::unordered_map<int, size_t> branch_id_to_index_;
// Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;
@ -130,6 +148,8 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
// The dependent input controls number.
size_t input_controls_num_{0};
CNodePtr node_;
int local_branch_id_;
size_t input_branch_id_num_;
// The output_data_ corresponds to the output_data_arrows_ one by one.
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_;

View File

@ -2017,6 +2017,81 @@ void GraphScheduler::LinkControlArrowForGatherActor(std::vector<GatherActorPtr>
}
}
void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *switch_actors,
LoopCountActor *to_actor,
const KernelMapPosition &origin_outputs_order) {
if (to_actor == nullptr || (*switch_actors).empty()) {
return;
}
for (auto &switch_actor : (*switch_actors)) {
for (size_t i = 0; i < switch_actor->output_branch_arrows_.size(); ++i) {
const auto &arrows = switch_actor->output_branch_arrows_[i];
if (arrows.empty() && switch_actor->branch_func_graph_[i] != nullptr) {
const auto &actor_name = switch_actor->branch_func_graph_[i]->ToString();
const auto &actor = FetchActor(actor_name);
if (actor != nullptr) {
const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
switch_actor->output_branch_control_arrows_[i].emplace_back(gather_actor->GetAID());
gather_actor->input_controls_num_++;
}
}
}
}
std::set<AnfNodePtr> call_nodes;
for (const auto &output : origin_outputs_order) {
if (IsCallNode(output.first.first)) {
call_nodes.insert(output.first.first);
}
}
to_actor->branch_id_to_input_controls_num_[kMainBranchID] += call_nodes.size();
for (const auto &call_node : call_nodes) {
const auto &func_graphs = FetchFuncGraphbyCallNode(call_node->cast<CNodePtr>());
for (const auto func_graph : func_graphs) {
MS_EXCEPTION_IF_NULL(func_graph);
const auto &actor_name = func_graph->get_return()->DebugString();
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
size_t branch_index = switch_actor->branch_id_to_index_.size();
if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) {
branch_index = switch_actor->branch_id_to_index_[kMainBranchID];
} else {
switch_actor->branch_id_to_index_[kMainBranchID] = branch_index;
}
switch_actor->output_branch_control_arrows_[branch_index].emplace_back(to_actor->GetAID());
}
}
}
void GraphScheduler::LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info,
const ActorSet *actor_set) {
for (const auto &control_node : graph_compiler_info.control_nodes_) {
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
const auto &actor_name = control_node->DebugString();
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
for (size_t i = 0; i < switch_actor->branch_func_graph_.size(); ++i) {
const auto &func_graph = switch_actor->branch_func_graph_[i];
if (func_graph == nullptr || func_graph->output()->isa<ValueNode>()) {
continue;
}
const auto &gather_actor = FetchActor(func_graph->ToString());
MS_EXCEPTION_IF_NULL(gather_actor);
switch_actor->output_branch_branch_arrows_[i].emplace_back(gather_actor->GetAID());
}
}
}
}
void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info,
const ActorSet *actor_set) {
if (graph_compiler_info.control_nodes_.empty()) {
@ -2571,5 +2646,92 @@ void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compil
ofs << "\n";
}
}
void GraphScheduler::DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) const {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
ofs << "\t\tactor input num:" << actor->data_nodes_.size() << "\n";
for (const auto &node : actor->data_nodes_) {
ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node) << '\n';
}
ofs << "\t\tactor front to backend node:\n";
for (const auto &front_to_backend_parameter : actor->front_to_backend_parameter_) {
ofs << "\t\t\tfront node:" << AnfAlgo::GetNodeDebugString(front_to_backend_parameter.first) << '\n';
for (const auto node_with_index : front_to_backend_parameter.second) {
ofs << "\t\t\t\tbackend node:" << AnfAlgo::GetNodeDebugString(node_with_index.first)
<< "\tindex:" << node_with_index.second << '\n';
}
}
ofs << "\t\tactor output data arrow:\n";
for (const auto &data_arrow : actor->output_data_arrows_) {
MS_EXCEPTION_IF_NULL(data_arrow);
ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
<< "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
<< "\n";
}
ofs << "\t\tactor output result arrow:\n";
for (const auto &result_arrow : actor->output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name() << "\tto_input_index:" << result_arrow->to_input_index_
<< "\n";
}
ofs << "\t\tactor output control arrow:\n";
for (const auto &control_arrow : actor->output_control_arrows_) {
ofs << "\t\t\tto_actor_name:" << control_arrow;
}
}
void GraphScheduler::DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) const {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
ofs << "\t\tactor input num:" << actor->input_nodes_.size() << "\n";
for (const auto &node : actor->input_nodes_) {
ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node) << '\n';
}
ofs << "\t\tactor input pos:\n";
for (size_t i = 0; i < actor->branch_inputs_pos_.size(); ++i) {
ofs << "\t\t\tbranch " << i << " input pos:";
for (const auto pos : actor->branch_inputs_pos_[i]) {
ofs << pos << '\t';
}
ofs << '\n';
}
ofs << "\t\tactor output data arrow:\n";
for (size_t i = 0; i < actor->output_branch_arrows_.size(); ++i) {
ofs << "\t\t\tbranch " << i << " output data:\n";
for (const auto arrow : actor->output_branch_arrows_[i]) {
MS_EXCEPTION_IF_NULL(arrow);
ofs << "\t\t\t\t from index:" << arrow->from_output_index_ << "\tto_actor_name:" << arrow->to_op_id_
<< "\tto_input_index:" << arrow->to_input_index_ << '\n';
}
}
ofs << "\t\tactor output result arrow:\n";
for (size_t i = 0; i < actor->output_branch_result_arrows_.size(); ++i) {
ofs << "\t\t\tbranch " << i << " output result:\n";
for (const auto arrow : actor->output_branch_result_arrows_[i]) {
MS_EXCEPTION_IF_NULL(arrow);
ofs << "\t\t\t\t from index:" << arrow->from_output_index_ << "\tto_actor_name:" << arrow->to_op_id_
<< "\tto_input_index:" << arrow->to_input_index_ << '\n';
}
}
ofs << "\t\tactor output control arrow:\n";
for (size_t i = 0; i < actor->output_branch_control_arrows_.size(); ++i) {
ofs << "\t\t\tbranch " << i << " output control:\n";
for (const auto arrow : actor->output_branch_control_arrows_[i]) {
ofs << "\t\t\t\t from index:" << arrow << '\n';
}
}
}
} // namespace runtime
} // namespace mindspore

View File

@ -235,8 +235,11 @@ class GraphScheduler {
void LinkDataArrowForSwitchActor(SwitchActor *from_actor, KernelActor *to_actor, const size_t to_index);
void LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors, LoopCountActor *to_actor,
const std::vector<KernelGraphPtr> &graphs);
void LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *switch_actors, LoopCountActor *to_actor,
const KernelMapPosition &origin_outputs_order);
// In control flow, there are scenarios where there are multi-branch outputs, and the gather actor needs to
// send the branch id to the loop count actor.
void LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
void LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
void LinkOutputResultArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
void LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
@ -277,6 +280,8 @@ class GraphScheduler {
void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const;
void DumpOutputActor(const OutputActor *actor, std::ofstream &ofs) const;
void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) const;
void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) const;
void DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) const;
void DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const;
// The global maps, only be cleared in the deconstruction.