forked from mindspore-Ecosystem/mindspore
!18895 Add dump interface and link control arrow for switch actor.
Merge pull request !18895 from gaoyong10/new_runtime11
This commit is contained in:
commit
d1fe65892e
|
@ -21,6 +21,7 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <stack>
|
||||
#include <unordered_map>
|
||||
#include "runtime/framework/actor/actor_common.h"
|
||||
#include "runtime/framework/device_tensor_store.h"
|
||||
|
@ -103,18 +104,35 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
|
|||
// The position of the branch output in the input_nodes_.
|
||||
std::vector<std::vector<size_t>> branch_inputs_pos_;
|
||||
|
||||
std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;
|
||||
|
||||
std::unordered_map<uuids::uuid *, std::unordered_map<AID *, size_t>> input_controls_;
|
||||
|
||||
// Branch ids are used to record the id corresponding to the switch output branch.
|
||||
// In control flow, a sub funcgraph may be called in multiple places, and the output must be returned to different
|
||||
// places. Therefore, the output of each subgraph will be connected to a switch actor, and the caller will send
|
||||
// its branch id to the gather of the subgraph. Then branch id will be sent by the gather actor to the switch
|
||||
// actor connected to the output.
|
||||
// In a recursive scenario, the switch will sequentially receive the branch ids sent by the caller, and the switch
|
||||
// actor needs to store the branch ids in the stack, and pop them in turn when returning.
|
||||
std::unordered_map<uuids::uuid *, std::stack<int>> input_branch_ids_;
|
||||
|
||||
// Control arrows of different branches.
|
||||
std::vector<std::vector<AID>> output_branch_control_arrows_;
|
||||
// Branch id arrows of different branches.
|
||||
std::vector<std::vector<AID>> output_branch_branch_arrows_;
|
||||
// Result arrows of different branches.
|
||||
std::vector<std::vector<DataArrowPtr>> output_branch_result_arrows_;
|
||||
|
||||
// When the output is a value node from switch actor, the actor needs to send the anfnode to the output actor,
|
||||
// so all the nodes that may send the device tensor to switch actor are recorded.
|
||||
std::vector<std::vector<KernelWithIndex>> front_to_backend_parameter_;
|
||||
|
||||
std::vector<std::vector<KernelWithIndex>> backend_parameters_;
|
||||
std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;
|
||||
std::vector<FuncGraphPtr> branch_func_graph_;
|
||||
|
||||
std::unordered_map<int, size_t> branch_id_to_index_;
|
||||
|
||||
// Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
|
||||
std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;
|
||||
|
||||
|
@ -130,6 +148,8 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
|
|||
// The dependent input controls number.
|
||||
size_t input_controls_num_{0};
|
||||
CNodePtr node_;
|
||||
int local_branch_id_;
|
||||
size_t input_branch_id_num_;
|
||||
|
||||
// The output_data_ corresponds to the output_data_arrows_ one by one.
|
||||
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_;
|
||||
|
|
|
@ -2017,6 +2017,81 @@ void GraphScheduler::LinkControlArrowForGatherActor(std::vector<GatherActorPtr>
|
|||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *switch_actors,
|
||||
LoopCountActor *to_actor,
|
||||
const KernelMapPosition &origin_outputs_order) {
|
||||
if (to_actor == nullptr || (*switch_actors).empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto &switch_actor : (*switch_actors)) {
|
||||
for (size_t i = 0; i < switch_actor->output_branch_arrows_.size(); ++i) {
|
||||
const auto &arrows = switch_actor->output_branch_arrows_[i];
|
||||
if (arrows.empty() && switch_actor->branch_func_graph_[i] != nullptr) {
|
||||
const auto &actor_name = switch_actor->branch_func_graph_[i]->ToString();
|
||||
const auto &actor = FetchActor(actor_name);
|
||||
if (actor != nullptr) {
|
||||
const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
|
||||
switch_actor->output_branch_control_arrows_[i].emplace_back(gather_actor->GetAID());
|
||||
gather_actor->input_controls_num_++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::set<AnfNodePtr> call_nodes;
|
||||
for (const auto &output : origin_outputs_order) {
|
||||
if (IsCallNode(output.first.first)) {
|
||||
call_nodes.insert(output.first.first);
|
||||
}
|
||||
}
|
||||
|
||||
to_actor->branch_id_to_input_controls_num_[kMainBranchID] += call_nodes.size();
|
||||
|
||||
for (const auto &call_node : call_nodes) {
|
||||
const auto &func_graphs = FetchFuncGraphbyCallNode(call_node->cast<CNodePtr>());
|
||||
for (const auto func_graph : func_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
const auto &actor_name = func_graph->get_return()->DebugString();
|
||||
auto actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
|
||||
size_t branch_index = switch_actor->branch_id_to_index_.size();
|
||||
if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) {
|
||||
branch_index = switch_actor->branch_id_to_index_[kMainBranchID];
|
||||
} else {
|
||||
switch_actor->branch_id_to_index_[kMainBranchID] = branch_index;
|
||||
}
|
||||
|
||||
switch_actor->output_branch_control_arrows_[branch_index].emplace_back(to_actor->GetAID());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info,
|
||||
const ActorSet *actor_set) {
|
||||
for (const auto &control_node : graph_compiler_info.control_nodes_) {
|
||||
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
|
||||
AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
|
||||
const auto &actor_name = control_node->DebugString();
|
||||
auto actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
for (size_t i = 0; i < switch_actor->branch_func_graph_.size(); ++i) {
|
||||
const auto &func_graph = switch_actor->branch_func_graph_[i];
|
||||
if (func_graph == nullptr || func_graph->output()->isa<ValueNode>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &gather_actor = FetchActor(func_graph->ToString());
|
||||
MS_EXCEPTION_IF_NULL(gather_actor);
|
||||
switch_actor->output_branch_branch_arrows_[i].emplace_back(gather_actor->GetAID());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info,
|
||||
const ActorSet *actor_set) {
|
||||
if (graph_compiler_info.control_nodes_.empty()) {
|
||||
|
@ -2571,5 +2646,92 @@ void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compil
|
|||
ofs << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) const {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
|
||||
|
||||
ofs << "\t\tactor input num:" << actor->data_nodes_.size() << "\n";
|
||||
for (const auto &node : actor->data_nodes_) {
|
||||
ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node) << '\n';
|
||||
}
|
||||
|
||||
ofs << "\t\tactor front to backend node:\n";
|
||||
for (const auto &front_to_backend_parameter : actor->front_to_backend_parameter_) {
|
||||
ofs << "\t\t\tfront node:" << AnfAlgo::GetNodeDebugString(front_to_backend_parameter.first) << '\n';
|
||||
for (const auto node_with_index : front_to_backend_parameter.second) {
|
||||
ofs << "\t\t\t\tbackend node:" << AnfAlgo::GetNodeDebugString(node_with_index.first)
|
||||
<< "\tindex:" << node_with_index.second << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
ofs << "\t\tactor output data arrow:\n";
|
||||
for (const auto &data_arrow : actor->output_data_arrows_) {
|
||||
MS_EXCEPTION_IF_NULL(data_arrow);
|
||||
ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
|
||||
<< "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
|
||||
<< "\n";
|
||||
}
|
||||
|
||||
ofs << "\t\tactor output result arrow:\n";
|
||||
for (const auto &result_arrow : actor->output_result_arrows_) {
|
||||
MS_EXCEPTION_IF_NULL(result_arrow);
|
||||
ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
|
||||
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name() << "\tto_input_index:" << result_arrow->to_input_index_
|
||||
<< "\n";
|
||||
}
|
||||
|
||||
ofs << "\t\tactor output control arrow:\n";
|
||||
for (const auto &control_arrow : actor->output_control_arrows_) {
|
||||
ofs << "\t\t\tto_actor_name:" << control_arrow;
|
||||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) const {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
|
||||
|
||||
ofs << "\t\tactor input num:" << actor->input_nodes_.size() << "\n";
|
||||
for (const auto &node : actor->input_nodes_) {
|
||||
ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node) << '\n';
|
||||
}
|
||||
|
||||
ofs << "\t\tactor input pos:\n";
|
||||
for (size_t i = 0; i < actor->branch_inputs_pos_.size(); ++i) {
|
||||
ofs << "\t\t\tbranch " << i << " input pos:";
|
||||
for (const auto pos : actor->branch_inputs_pos_[i]) {
|
||||
ofs << pos << '\t';
|
||||
}
|
||||
ofs << '\n';
|
||||
}
|
||||
|
||||
ofs << "\t\tactor output data arrow:\n";
|
||||
for (size_t i = 0; i < actor->output_branch_arrows_.size(); ++i) {
|
||||
ofs << "\t\t\tbranch " << i << " output data:\n";
|
||||
for (const auto arrow : actor->output_branch_arrows_[i]) {
|
||||
MS_EXCEPTION_IF_NULL(arrow);
|
||||
ofs << "\t\t\t\t from index:" << arrow->from_output_index_ << "\tto_actor_name:" << arrow->to_op_id_
|
||||
<< "\tto_input_index:" << arrow->to_input_index_ << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
ofs << "\t\tactor output result arrow:\n";
|
||||
for (size_t i = 0; i < actor->output_branch_result_arrows_.size(); ++i) {
|
||||
ofs << "\t\t\tbranch " << i << " output result:\n";
|
||||
for (const auto arrow : actor->output_branch_result_arrows_[i]) {
|
||||
MS_EXCEPTION_IF_NULL(arrow);
|
||||
ofs << "\t\t\t\t from index:" << arrow->from_output_index_ << "\tto_actor_name:" << arrow->to_op_id_
|
||||
<< "\tto_input_index:" << arrow->to_input_index_ << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
ofs << "\t\tactor output control arrow:\n";
|
||||
for (size_t i = 0; i < actor->output_branch_control_arrows_.size(); ++i) {
|
||||
ofs << "\t\t\tbranch " << i << " output control:\n";
|
||||
for (const auto arrow : actor->output_branch_control_arrows_[i]) {
|
||||
ofs << "\t\t\t\t from index:" << arrow << '\n';
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -235,8 +235,11 @@ class GraphScheduler {
|
|||
void LinkDataArrowForSwitchActor(SwitchActor *from_actor, KernelActor *to_actor, const size_t to_index);
|
||||
void LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors, LoopCountActor *to_actor,
|
||||
const std::vector<KernelGraphPtr> &graphs);
|
||||
void LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *switch_actors, LoopCountActor *to_actor,
|
||||
const KernelMapPosition &origin_outputs_order);
|
||||
// In control flow, there are scenarios where there are multi-branch outputs, and the gather actor needs to
|
||||
// send the branch id to the loop count actor.
|
||||
void LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
|
||||
void LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
|
||||
void LinkOutputResultArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
|
||||
void LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
|
||||
|
@ -277,6 +280,8 @@ class GraphScheduler {
|
|||
void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const;
|
||||
void DumpOutputActor(const OutputActor *actor, std::ofstream &ofs) const;
|
||||
void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) const;
|
||||
void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) const;
|
||||
void DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) const;
|
||||
void DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const;
|
||||
|
||||
// The global maps, which are only cleared in the destructor.
|
||||
|
|
Loading…
Reference in New Issue