fix control flow bugs.

This commit is contained in:
gaoyong10 2021-07-09 15:23:02 +08:00
parent 2c94b631a7
commit 75ab763696
6 changed files with 75 additions and 224 deletions

View File

@ -99,45 +99,14 @@ void SwitchActor::ParsePartialInput(const AnfNodePtr &node, const size_t branch_
}
auto func_graph = GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]);
if (func_graph->output()->isa<ValueNode>()) {
AddInput(func_graph->output(), branch_id);
return;
} else if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) {
// If the funcgraph called by the partial returns a partial node, the switch actor should call the funcgraph
// of the sub partial. Similarly, the input node should also be the input of the sub partial.
is_mulit_call_ = true;
CNodePtr sub_partial = func_graph->output()->cast<CNodePtr>();
const auto &sub_partial_inputs = sub_partial->inputs();
if (sub_partial_inputs.size() <= kPartialFuncGraphPos) {
MS_LOG(EXCEPTION) << "Invalid Partial node:" << AnfAlgo::GetNodeDebugString(sub_partial);
}
const auto &sub_func_graph = GetValueNode<FuncGraphPtr>(sub_partial_inputs[kPartialFuncGraphPos]);
if (sub_func_graph->output()->isa<ValueNode>()) {
AddInput(sub_func_graph->output(), branch_id);
return;
}
branch_func_graph_[branch_id] = sub_func_graph;
const auto &sub_parameters = func_graph->parameters();
// Record the input that comes with the sub partial node.
for (size_t i = kPartialInputStartPos; i < sub_partial_inputs.size(); ++i) {
const auto &real_partial_input = AnfAlgo::VisitKernelWithReturnType(sub_partial_inputs[i], 0).first;
const auto &iter = find(sub_parameters.begin(), sub_parameters.end(), real_partial_input);
if ((iter != sub_parameters.end()) &&
((iter - sub_parameters.begin()) < SizeToInt(partial_inputs.size() - kPartialInputStartPos))) {
AddInput(partial_inputs[iter - sub_parameters.begin() + kPartialInputStartPos], branch_id);
}
}
return;
}
branch_func_graph_[branch_id] = func_graph;
for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {
AddInput(partial_inputs[j], branch_id);
}
} else if (IsValueNode<FuncGraph>(node)) {
const auto func_graph = GetValueNode<FuncGraphPtr>(node);
branch_func_graph_[branch_id] = func_graph;
} else {
AddInput(node, branch_id);
}
@ -158,11 +127,6 @@ void SwitchActor::ParseReturnInput(const ControlNodeParserPtr &parser) {
const auto &call_num = parser->GetCallNumByFuncGraph(func_graph);
InitVectorSize(call_num);
// If the return is a partial node or funcgraph, this subgraph will not be initialized and no input is required.
if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial) ||
(func_graph->output()->isa<ValueNode>() && IsValueNode<FuncGraph>(func_graph->output()))) {
return;
}
AddCommonInput(func_graph->output());
}
@ -210,26 +174,7 @@ void SwitchActor::ParseSwitchLayerInput() {
if (AnfAlgo::CheckPrimitiveType(branch_nodes[i], prim::kPrimPartial)) {
ParsePartialInput(branch_nodes[i], i - kMakeTupleInputStartPos);
} else if (branch_nodes[i]->isa<ValueNode>()) {
const auto &func_graph = GetValueNode<FuncGraphPtr>(branch_nodes[i]);
const auto output = func_graph->output();
// The switch layer node has a second-order call connected to call. When the called funcgraph returns a partial
// node or funcgraph, the switch actor needs to call the funcgraph directly.
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) {
is_mulit_call_ = true;
branch_func_graph_[i - kMakeTupleInputStartPos] =
GetValueNode<FuncGraphPtr>(output->cast<CNodePtr>()->input(kPartialFuncGraphPos));
} else if (output->isa<ValueNode>() && IsValueNode<FuncGraph>(output)) {
is_mulit_call_ = true;
const auto &sub_func_graph = GetValueNode<FuncGraphPtr>(output);
if (sub_func_graph->output()->isa<ValueNode>()) {
AddInput(sub_func_graph->output(), i - kMakeTupleInputStartPos);
continue;
}
branch_func_graph_[i - kMakeTupleInputStartPos] = GetValueNode<FuncGraphPtr>(output);
} else {
branch_func_graph_[i - kMakeTupleInputStartPos] = func_graph;
}
branch_func_graph_[i - 1] = GetValueNode<FuncGraphPtr>(branch_nodes[i]);
}
}
}
@ -256,9 +201,7 @@ void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t b
// The value node and weight node need to be placed in the device store. The switch actor has three inputs:
// 1) The input of the switch is the value node.
// 2) There is a weight node or value node in the return of the sub funcgraph.
// 3) When the switch actor is a second-order call, it does not distinguish between weight and parameter.
if (((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) || is_mulit_call_) && node->isa<Parameter>() &&
HasAbstractRef(node)) ||
if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && node->isa<Parameter>() && HasAbstractRef(node)) ||
node->isa<ValueNode>()) {
const auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
if (iter != input_nodes_.end()) {
@ -308,6 +251,12 @@ void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
for (size_t i = 0; i < call_output_num; ++i) {
AddInput({real_input.first, i}, branch);
}
} else if (real_input.first->isa<ValueNode>() && real_input.first->cast<ValueNodePtr>()->value()->isa<ValueTuple>()) {
const auto &value = real_input.first->cast<ValueNodePtr>()->value();
const auto &tuple_value = value->cast<ValueTuplePtr>();
for (size_t i = 0; i < tuple_value->value().size(); ++i) {
AddInput({real_input.first, i}, branch);
}
} else {
AddInput(real_input, branch);
}

View File

@ -171,11 +171,6 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
// The output_data_ corresponds to the output_data_arrows_ one by one.
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_;
// Used to indicate that in the control flow, when the input of the call node is a call node, the switch actor
// corresponding to the switch node called by the sub call node. At this time, the funcgraph of the input of
// the switch actor will return to a partial node or funcgraph.
bool is_mulit_call_{false};
};
using SwitchActorPtr = std::shared_ptr<SwitchActor>;

View File

@ -22,7 +22,6 @@
namespace mindspore {
namespace runtime {
constexpr size_t kSingleCallDepth = 1;
namespace {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
// Fetch all the weight parameters related to node. It runs like this:
@ -69,6 +68,9 @@ void FetchParameterBySwitchNode(const AnfNodePtr &switch_node, FuncGraphToParame
for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) {
const auto &partial_node = switch_inputs[i];
if (IsValueNode<FuncGraph>(partial_node)) {
continue;
}
const auto &func_graph = GetFuncGraphFromPartial(partial_node);
std::vector<AnfNodePtr> parameters;
const auto &partial_inputs = partial_node->cast<CNodePtr>()->inputs();
@ -451,40 +453,6 @@ std::vector<AnfNodePtr> FetchParameterByControlNode(const std::vector<AnfNodePtr
return parameters;
}
// Get the number of calls, that is, the number of times the input of the call node is the call node.
size_t FetchCallDepth(const AnfNodePtr &node) {
if (!IsCallNode(node)) {
return 0;
}
const auto &cnode = node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
return kSingleCallDepth + FetchCallDepth(inputs[0]);
}
// Get the final subgraph called by fungraph through the depth of calls.
FuncGraphPtr FetchFuncGraphByCallDepth(const FuncGraphPtr &func_graph, const size_t call_depth) {
if (call_depth <= kSingleCallDepth) {
return func_graph;
}
const auto &output = func_graph->output();
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) {
const auto &cnode = output->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
if (inputs.size() < kPartialInputStartPos) {
MS_LOG(EXCEPTION) << "Invalid partial node:" << AnfAlgo::GetNodeDebugString(output);
}
const auto &called_func_graph = GetValueNode<FuncGraphPtr>(inputs[kPartialFuncGraphPos]);
return FetchFuncGraphByCallDepth(called_func_graph, call_depth - kSingleCallDepth);
} else if (output->isa<ValueNode>() && IsValueNode<FuncGraph>(output)) {
return FetchFuncGraphByCallDepth(GetValueNode<FuncGraphPtr>(output), call_depth - kSingleCallDepth);
} else {
MS_LOG(EXCEPTION) << "Invalid output for call depth:" << call_depth << " funcgraph:" << func_graph->ToString()
<< " output node:" << AnfAlgo::GetNodeDebugString(output);
}
}
// Get funcgraph from node, the interface only accepts partial node and funcgraph value node.
FuncGraphPtr FetchFuncGraphInNode(const auto &node) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
@ -516,6 +484,26 @@ FuncGraphPtr FetchFuncGraphInNode(const auto &node) {
}
} // namespace
AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes) {
const auto &real_node = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
if (!IsCallNode(real_node)) {
return real_node;
}
if ((*call_nodes).find(real_node) != (*call_nodes).end()) {
return nullptr;
}
(*call_nodes).insert(real_node);
const auto &func_graphs = FetchFuncGraphbyCallNode(real_node);
for (const auto &func_graph : func_graphs) {
const auto &output = FetchRealOutputByCallNode(func_graph->output(), call_nodes);
if (output != nullptr) {
return output;
}
}
return nullptr;
}
// Return true if the node has Ref abstract.
bool HasAbstractRef(const AnfNodePtr &node) {
if (node == nullptr) {
@ -602,6 +590,8 @@ std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node) {
for (size_t i = kSwitchTrueBranchPos; i < cnode_inputs.size(); ++i) {
if (IsPrimitiveCNode(cnode_inputs[i], prim::kPrimPartial)) {
func_graphs.emplace_back(GetFuncGraphFromPartial(cnode_inputs[i]));
} else if (IsValueNode<FuncGraph>(cnode_inputs[i])) {
func_graphs.emplace_back(GetValueNode<FuncGraphPtr>(cnode_inputs[i]));
}
}
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer) &&
@ -659,6 +649,11 @@ size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr>
break;
}
total_num += call_output_num;
} else if (inputs[i]->isa<ValueNode>() && inputs[i]->cast<ValueNodePtr>()->value()->isa<ValueTuple>()) {
auto value_tuple = inputs[i]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
auto tuple_value = value_tuple->value();
total_num += tuple_value.size();
} else if (!HasAbstractMonad(inputs[i])) {
++total_num;
}
@ -925,8 +920,12 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
}
if (front_to_backend_parameters_.find(parameters[i - kCallInputStartPos]) ==
front_to_backend_parameters_.end()) {
MS_LOG(EXCEPTION) << "Cannot find backend parameter for front parameter:"
<< AnfAlgo::GetNodeDebugString(parameters[i - kCallInputStartPos]);
MS_LOG(INFO) << "Cannot find backend parameter for front parameter:"
<< AnfAlgo::GetNodeDebugString(parameters[i - kCallInputStartPos])
<< ", used the default format";
CreateDeviceTensorForFrontParameter(inputs[i], device_contexts[0]);
front_value_nodes_.push_back({inputs[i], device_contexts[0]});
continue;
}
const auto &backend_node = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].first;
@ -1110,9 +1109,6 @@ void ControlNodeParser::FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &con
for (const auto &func_graph : func_graphs) {
MS_EXCEPTION_IF_NULL(func_graph);
if (func_graph->output()->isa<ValueNode>()) {
continue;
}
if (func_graph_to_call_num_.find(func_graph) == func_graph_to_call_num_.end()) {
func_graph_to_call_num_[func_graph] = 1;
@ -1147,13 +1143,6 @@ void ControlNodeParser::CreateBranchIDForFuncGraph(const std::vector<AnfNodePtr>
for (const auto &control_node : control_nodes) {
// Root funcgraph does not need to create a gather actor.
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
const auto &cnode = control_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
// If the output of funcgraph is a value node, no need to create gather actor.
if (inputs[kReturnInputPos]->isa<ValueNode>()) {
continue;
}
auto func_graph = control_node->func_graph();
func_graph_to_branch_id_[func_graph] = branch_id++;
}

View File

@ -53,6 +53,9 @@ bool IsCallNode(const AnfNodePtr &node);
// Check if the call node is the input of another call node.
bool IsSubCallNode(const AnfNodePtr &node);
// Recursive interface, find the real output of funcgraph called by call node.
AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes);
// Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph,
// it is determined whether it is a weight.
bool HasAbstractRef(const AnfNodePtr &node);

View File

@ -1253,12 +1253,9 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
continue;
}
// If the output of funcgraph is a value node, no need to create gather actor.
if (inputs[kReturnInputPos]->isa<ValueNode>() ||
AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimPartial)) {
if (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimPartial)) {
continue;
}
auto actor_name = func_graph->ToString();
std::vector<KernelWithIndex> parameters;
for (const auto &parameter : func_graph->get_inputs()) {
@ -2079,69 +2076,14 @@ void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr
// Before link data arrow, parameters of the call node in switch-call need to be add to the switch actor.
if (inputs[0]->isa<CNode>()) {
// Add the input of call node to switch actor.
if (IsCallNode(inputs[0])) {
const auto &sub_call_cnode = inputs[0]->cast<CNodePtr>();
const auto &sub_inputs = sub_call_cnode->inputs();
if (AnfAlgo::CheckPrimitiveType(sub_inputs[0], prim::kPrimSwitchLayer)) {
auto actor = FetchActor(sub_inputs[0]->DebugString());
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
switch_actor->AddCommonInput(inputs[i]);
}
}
} else if (IsSubCallNode(cnode)) {
// Add the input of sub call node to switch actor.
auto actor = FetchActor(inputs[0]->DebugString());
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
const auto &tuple_node = inputs[0]->cast<CNodePtr>()->input(kSwitchLayerBranchPos);
const auto &tuple_inputs = tuple_node->cast<CNodePtr>()->inputs();
FuncGraphPtr func_graph = nullptr;
for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
int pre_real_parameter_num = 0;
if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) {
pre_real_parameter_num = (tuple_inputs[i]->cast<CNodePtr>()->inputs().size() - kPartialInputStartPos);
func_graph = GetValueNode<FuncGraphPtr>(tuple_inputs[i]->cast<CNodePtr>()->input(kPartialFuncGraphPos));
} else {
func_graph = GetValueNode<FuncGraphPtr>(tuple_inputs[i]);
}
const auto parameters = func_graph->parameters();
const auto &output = func_graph->output();
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) {
const auto &sub_partial_inputs = output->cast<CNodePtr>()->inputs();
// Check whether the input node of the sub call node needs to be added to the switch actor. Only when
// the final return is a partial node and the partial node needs this input, the input node is added
// to the switch actor/
for (size_t j = kPartialInputStartPos; j < sub_partial_inputs.size(); ++j) {
const auto &real_partial_input = AnfAlgo::VisitKernelWithReturnType(sub_partial_inputs[j], 0).first;
const auto &iter = find(parameters.begin(), parameters.end(), real_partial_input);
if ((iter != parameters.end()) && (iter - parameters.begin() >= pre_real_parameter_num) &&
(iter - parameters.begin() <
SizeToInt(pre_real_parameter_num + inputs.size() - kCallInputStartPos))) {
size_t pos = iter - parameters.begin() - pre_real_parameter_num + kCallInputStartPos;
switch_actor->AddSingleInput(inputs[pos], i - 1);
}
}
}
}
} else {
auto actor = FetchActor(inputs[0]->DebugString());
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (HasAbstractMonad(inputs[i])) {
continue;
}
switch_actor->AddCommonInput(inputs[i]);
auto actor = FetchActor(inputs[0]->DebugString());
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (HasAbstractMonad(inputs[i])) {
continue;
}
switch_actor->AddCommonInput(inputs[i]);
}
}
}
@ -2219,10 +2161,10 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
for (const auto &input_with_index : gather_actor->data_nodes_) {
for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) {
const auto &input_with_index = gather_actor->data_nodes_[i];
const auto &from_func_graph = kernel_graph->GetFuncGraph();
LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, gather_actor,
gather_actor->FetchDataNodePosition(input_with_index));
LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, gather_actor, i);
}
}
LinkBranchArrowForSwitchActor(graph_compiler_info, actor_set);
@ -2258,47 +2200,9 @@ void GraphScheduler::LinkDataArrowByCallInput(const KernelWithIndex &call_node_w
// Fetch all the funcgraph that call node would call.
const auto cnode = call_node_with_index.first->cast<CNodePtr>();
std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(cnode);
const auto &call_inputs = cnode->inputs();
auto switch_node = call_inputs[0];
if (IsCallNode(switch_node)) {
switch_node = call_inputs[0]->cast<CNodePtr>()->input(0);
}
// Collect the output of each funcgraph.
for (const auto &func_graph : func_graphs) {
if (func_graph->output()->isa<ValueNode>()) {
if (AnfAlgo::CheckPrimitiveType(switch_node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(switch_node, prim::kPrimSwitchLayer)) {
const auto &actor_name = switch_node->DebugString();
const auto &actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
const auto &output_with_index = KernelWithIndex(func_graph->output(), 0);
const auto &iter =
find(switch_actor->input_nodes_.begin(), switch_actor->input_nodes_.end(), output_with_index);
if (iter == switch_actor->input_nodes_.end()) {
MS_LOG(EXCEPTION) << "Invalid input node for switch actor:" << switch_actor->GetAID()
<< " node:" << AnfAlgo::GetNodeDebugString(func_graph->output());
}
size_t pos = iter - switch_actor->input_nodes_.begin();
// Add output for each branch of switch.
for (size_t i = 0; i < switch_actor->branch_inputs_pos_.size(); ++i) {
const auto poses = switch_actor->branch_inputs_pos_[i];
if (find(poses.begin(), poses.end(), pos) == poses.end()) {
continue;
}
auto op_arrow = std::make_shared<DataArrow>(pos, to_actor->GetAID(), to_index);
switch_actor->output_branch_arrows_[i].emplace_back(op_arrow);
}
} else {
MS_LOG(EXCEPTION) << "Invalid funcgraph:" << func_graph->ToString();
}
continue;
}
const auto actor_name = func_graph->get_return()->DebugString();
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
@ -2443,7 +2347,7 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_
// Link switch output.
for (size_t i = 0; i < actor->branch_func_graph_.size(); ++i) {
auto func_graph = actor->branch_func_graph_[i];
if (func_graph == nullptr || func_graph->output()->isa<ValueNode>()) {
if (func_graph == nullptr) {
continue;
}
@ -2547,6 +2451,18 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr>
// If there is no output from the switch actor branch, it means that the subgraph has no input,
// and need to connect a control arrow to the corresponding gather actor.
for (auto &switch_actor : (*switch_actors)) {
if (AnfAlgo::CheckPrimitiveType(switch_actor->node_, prim::kPrimReturn)) {
const auto &func_graph = switch_actor->node_->func_graph();
if (func_graph->output()->isa<ValueNode>()) {
const auto &actor_name = func_graph->ToString();
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
gather_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
switch_actor->input_controls_num_++;
}
}
for (size_t i = 0; i < switch_actor->output_branch_arrows_.size(); ++i) {
const auto &arrows = switch_actor->output_branch_arrows_[i];
if (arrows.empty() && switch_actor->branch_func_graph_[i] != nullptr) {
@ -2603,7 +2519,7 @@ void GraphScheduler::LinkBranchArrowForSwitchActor(const GraphCompilerInfo &grap
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
for (size_t i = 0; i < switch_actor->branch_func_graph_.size(); ++i) {
const auto &func_graph = switch_actor->branch_func_graph_[i];
if (func_graph == nullptr || func_graph->output()->isa<ValueNode>()) {
if (func_graph == nullptr) {
continue;
}

View File

@ -258,7 +258,6 @@ class GraphScheduler {
// send the branch id to the loop count actor.
void LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
void LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
void LinkOutputResultArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
void LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
void PrepareDataForControlNode(HostQueueDataSourceActor *host_data_source_actor,
const ControlNodeParserPtr &control_node_parser,