!20669 Fix codedex and pclint.

Merge pull request !20669 from gaoyong10/runtime_001
This commit is contained in:
i-robot 2021-07-22 13:24:09 +00:00 committed by Gitee
commit 138f381829
6 changed files with 59 additions and 46 deletions

View File

@ -38,7 +38,7 @@ void GatherActor::Init() {
auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
output_data_.emplace_back(data.get());
output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(std::move(data));
output_data_by_output_index_[IntToSize(data_arrow->from_output_index_)].emplace_back(std::move(data));
}
}
@ -113,7 +113,7 @@ void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
// 2.Send output result.
for (const auto &result_arrow : output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
size_t from_index = result_arrow->from_output_index_;
size_t from_index = IntToSize(result_arrow->from_output_index_);
const auto &front_node = data_nodes_[from_index].first;
for (const auto &backend_node : front_to_backend_parameter_.at(front_node)) {
if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second, false).get() ==

View File

@ -85,7 +85,7 @@ class GatherActor : public OpActor<DeviceTensor> {
// The device tensors for launch.
std::vector<DeviceTensor *> input_device_tensors_;
// The branch if for current step.
int input_branch_id_;
int input_branch_id_{kInvalidBranchID};
// Input data.
std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;

View File

@ -208,7 +208,7 @@ void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t b
branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
return;
}
device_tensor_store_keys_.push_back({input_nodes_.size(), node.get()});
device_tensor_store_keys_.emplace_back(input_nodes_.size(), node.get());
branch_inputs_pos_[branch].push_back(input_nodes_.size());
input_nodes_.push_back(node_with_index);
return;
@ -245,7 +245,7 @@ void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
} else if (IsCallNode(real_input.first)) {
std::vector<AnfNodePtr> call_nodes;
const auto call_output_num = FetchOutputSizebyCallNode(real_input.first, &call_nodes);
if (call_output_num <= 0) {
if (call_output_num == 0) {
MS_LOG(EXCEPTION) << "Invalid output num for call input:" << AnfAlgo::GetNodeDebugString(real_input.first);
}
for (size_t i = 0; i < call_output_num; ++i) {
@ -268,7 +268,7 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
input_branch_ids_[context->sequential_num_].empty()) {
MS_LOG(ERROR) << "Invalid branch id for actor:" + GetAID().Name();
}
size_t branch_id = input_branch_ids_[context->sequential_num_].top();
auto branch_id = input_branch_ids_[context->sequential_num_].top();
input_branch_ids_[context->sequential_num_].pop();
if (branch_id_to_index_.find(branch_id) == branch_id_to_index_.end()) {
MS_LOG(ERROR) << "Invalid branch id for switch actor:" + GetAID().Name() +
@ -278,9 +278,8 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
}
DeviceTensor *device_tensor = input_device_tensors_[0];
if (device_tensor == nullptr) {
MS_LOG(ERROR) << "Index of switch actor is empty:" + GetAID().Name();
}
MS_EXCEPTION_IF_NULL(device_tensor);
auto inputs = node_->inputs();
TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0);
size_t size = abstract::TypeIdSize(type_id);
@ -291,7 +290,10 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
int64_t index = 0;
char buf[kMaxSwitchCondSize] = {0};
ShapeVector host_shape;
device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf));
if (!device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf))) {
MS_LOG(ERROR) << GetAID().Name() + " get index from device address failed, type id:" + std::to_string(type_id) +
", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
}
if (type_id == TypeId::kNumberTypeInt32) {
index = static_cast<int64_t>((static_cast<int32_t *>(static_cast<void *>(buf)))[0]);
@ -306,7 +308,7 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
// SwitchLayer node support negative index range [-size, -1].
if (index < 0) {
index += branch_func_graph_.size();
index += SizeToInt(branch_func_graph_.size());
}
return static_cast<size_t>(index);
}
@ -403,7 +405,7 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
" total:" + std::to_string(branch_inputs_pos_[index].size()) + " actor:" + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
size_t from_index = branch_inputs_pos_[index][result_arrow->from_output_index_];
size_t from_index = branch_inputs_pos_[index][IntToSize(result_arrow->from_output_index_)];
MS_LOG(DEBUG) << "Switch actor:" << GetAID() << " send result addr:" << input_device_tensors_[from_index];
bool is_send = false;

View File

@ -556,7 +556,7 @@ bool IsSubCallNode(const AnfNodePtr &node) {
std::vector<KernelWithIndex> FetchAllRealInputNodeByParameter(const KernelWithIndex &node) {
std::vector<KernelWithIndex> parameters;
const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second);
const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, SizeToInt(node.second));
const auto &real_node = real_node_with_index.first;
if (real_node->isa<Parameter>()) {
if (!HasAbstractRef(real_node) && !HasAbstractMonad(real_node)) {
@ -749,7 +749,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
}
}
FetchFrontToBackendParameter(graphs, device_contexts, control_nodes, real_to_formal_front_parameters,
FetchFrontToBackendParameter(graphs, device_contexts, real_to_formal_front_parameters,
formal_to_real_front_parameters);
FetchFuncGraphToParameter(control_nodes);
@ -924,14 +924,14 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
<< AnfAlgo::GetNodeDebugString(parameters[i - kCallInputStartPos])
<< ", used the default format";
CreateDeviceTensorForFrontParameter(inputs[i], device_contexts[0]);
front_value_nodes_.push_back({inputs[i], device_contexts[0]});
front_value_nodes_.emplace_back(inputs[i], device_contexts[0]);
continue;
}
const auto &backend_node = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].first;
const auto &device_context = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].second;
CreateDeviceTensorForValueNode(inputs[i], backend_node, device_context);
front_value_nodes_.push_back({inputs[i], device_context});
front_value_nodes_.emplace_back(inputs[i], device_context);
}
}
}
@ -940,18 +940,15 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
for (size_t index = 0; index < graphs.size(); ++index) {
const auto &graph = graphs[index];
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();
for (const auto &parameter : graph->input_nodes()) {
const auto &front_node = graph->GetFrontAnfByBackendAnf(parameter);
const auto &internal_node = graph->GetFrontNodeByInternalParameter(parameter);
MS_EXCEPTION_IF_NULL(parameter);
if (IsInternalParameter(parameter, graph)) {
auto front_node_with_index = graph->GetFrontNodeByInternalParameter(parameter);
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
const auto &front_output_with_index =
AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false);
const auto &front_output_with_index = AnfAlgo::VisitKernelWithReturnType(
front_node_with_index.first, SizeToInt(front_node_with_index.second), false);
auto front_output_node = front_output_with_index.first;
MS_EXCEPTION_IF_NULL(front_output_node);
if (AnfAlgo::CheckPrimitiveType(front_output_node, prim::kPrimSwitch)) {
@ -959,7 +956,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
FetchValueNodeBySwitchNode(front_output_node, &value_nodes);
for (const auto value_node : value_nodes) {
CreateDeviceTensorForValueNode(value_node, parameter, device_contexts[index]);
front_value_nodes_.push_back({value_node, device_contexts[index]});
front_value_nodes_.emplace_back(value_node, device_contexts[index]);
}
}
}
@ -975,7 +972,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
if (output->isa<ValueNode>() && GetFrontValueNodeDeviceContext(output) == nullptr) {
const auto &device_context = call_node_to_backend_parameter.second.second;
CreateDeviceTensorForValueNode(output, call_node_to_backend_parameter.second.first, device_context);
front_value_nodes_.push_back({output, device_context});
front_value_nodes_.emplace_back(output, device_context);
}
}
}
@ -1054,7 +1051,7 @@ void ControlNodeParser::FetchFrontToFrontParameter(
if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
std::vector<AnfNodePtr> call_inputs;
call_inputs.assign(inputs.begin() + kCallInputStartPos, inputs.end());
call_inputs.assign(inputs.begin() + SizeToInt(kCallInputStartPos), inputs.end());
switch_input_parse(inputs[0], call_inputs);
} else if (IsCallNode(inputs[0])) {
continue;
@ -1095,7 +1092,7 @@ std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std::
if (backend_iter == front_to_backend_parameters_.end()) {
CreateDeviceTensorForFrontParameter(parameter, device_context);
front_to_backend_parameters_[parameter] = {parameter, device_context};
front_parameters_.push_back({parameter, device_context});
front_parameters_.emplace_back(parameter, device_context);
}
}
@ -1215,7 +1212,6 @@ std::vector<KernelWithIndex> FetchParameterbyKernelGraph(const KernelGraphPtr &g
void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const std::vector<AnfNodePtr> &control_nodes,
const RealToFormalNode &real_to_formal_front_parameters,
const RealToFormalNode &formal_to_real_front_parameters) {
if (graphs.size() != device_contexts.size()) {
@ -1356,12 +1352,12 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
std::set<AnfNodePtr> *switch_nodes,
std::set<KernelWithIndex> *results) {
if (front_output->isa<ValueNode>()) {
(*results).insert({front_output, 0});
(*results).emplace(front_output, 0);
const auto &iter = formal_to_real_parameters_.find(front_output);
if (iter != formal_to_real_parameters_.end()) {
for (const auto &node : iter->second) {
(*results).insert(node);
(*results).emplace(node);
}
}
} else if (front_output->isa<Parameter>()) {
@ -1370,7 +1366,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
if (iter != formal_to_real_parameters_.end()) {
for (const auto &node : iter->second) {
(*results).insert(node);
(*results).emplace(node);
}
} else {
MS_LOG(EXCEPTION) << "Cannot find backend node for front parameter:" << AnfAlgo::GetNodeDebugString(front_output);
@ -1402,7 +1398,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
const auto iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(front_output, 0));
if (iter != front_to_backend_kernels_.end()) {
(*results).insert(iter->second.first);
(*results).emplace(iter->second.first);
} else {
MS_LOG(EXCEPTION) << "Cannot find backend node for front kernel:" << AnfAlgo::GetNodeDebugString(front_output);
}
@ -1429,7 +1425,7 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(
MS_LOG(WARNING) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first);
continue;
}
formal_to_real_parameters_[formal_parameter].push_back({iter->second.first, 0});
formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first, 0);
} else {
const auto iter = front_to_backend_kernels_.find(node_with_index);
if (iter == front_to_backend_kernels_.end()) {
@ -1439,7 +1435,7 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(
}
}
} else if (real_parameter->isa<ValueNode>()) {
formal_to_real_parameters_[formal_parameter].push_back({real_parameter, 0});
formal_to_real_parameters_[formal_parameter].emplace_back(real_parameter, 0);
} else if (IsCallNode(real_parameter)) {
const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter);
for (const auto func_graph : func_graphs) {
@ -1512,7 +1508,7 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr>
auto front_node = graph->GetFrontAnfByBackendAnf(value_node);
if (front_node != nullptr) {
formal_to_real_parameters_[front_node].push_back({value_node, 0});
formal_to_real_parameters_[front_node].emplace_back(value_node, 0);
}
}
}
@ -1521,7 +1517,7 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr>
for (const auto &front_weight : host_parameter_to_weight.second) {
const auto &iter = front_to_backend_parameters_.find(host_parameter_to_weight.first);
if (iter != front_to_backend_parameters_.end()) {
formal_to_real_parameters_[front_weight].push_back({iter->second.first, 0});
formal_to_real_parameters_[front_weight].emplace_back(iter->second.first, 0);
}
}
}
@ -1552,10 +1548,10 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr>
}
}
for (const auto parameter_pair : front_to_backend_parameters) {
formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0});
formal_to_real_parameters_[parameter_pair.first].emplace_back(parameter_pair.second.first, 0);
}
for (const auto parameter_pair : front_to_backend_parameters_) {
formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0});
formal_to_real_parameters_[parameter_pair.first].emplace_back(parameter_pair.second.first, 0);
}
}

View File

@ -150,7 +150,6 @@ class ControlNodeParser {
// 2. The parameter from control nodes.
void FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const std::vector<AnfNodePtr> &control_nodes,
const RealToFormalNode &real_to_formal_front_parameters,
const RealToFormalNode &formal_to_real_front_parameters);
// Get the relationship between the front and backend of the executable kernel in all kernel graphs.

View File

@ -1262,7 +1262,7 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
if (HasAbstractMonad(parameter) || HasAbstractRef(parameter)) {
continue;
}
parameters.push_back({parameter, 0});
parameters.emplace_back(parameter, 0);
}
const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
@ -1291,7 +1291,7 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
if (HasAbstractMonad(inputs[i]) || (inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]))) {
continue;
}
parameters.push_back({inputs[i], 0});
parameters.emplace_back(inputs[i], 0);
}
auto func_graph = control_node->func_graph();
@ -1335,7 +1335,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf
if (from_kernel->isa<Parameter>() && graph_compiler_info.control_node_parser_->IsCallInputKernelGraph(graph)) {
const auto &kernel_with_index = GetFrontNodeByKernelGraph(from_kernel, graph);
const auto &real_front_node_with_index =
AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, kernel_with_index.second);
AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, SizeToInt(kernel_with_index.second));
if (HasAbstractRef(real_front_node_with_index.first)) {
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second,
real_front_node_with_index.first.get());
@ -1966,6 +1966,7 @@ void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
// Set branch index into switch actor.
size_t branch_index = switch_actor->branch_id_to_index_.size();
@ -2071,7 +2072,6 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes) {
for (const auto &node : control_nodes) {
CNodePtr cnode = node->cast<CNodePtr>();
const auto &from_func_graph = node->func_graph();
auto inputs = cnode->inputs();
// Before link data arrow, parameters of the call node in switch-call need to be add to the switch actor.
@ -2079,6 +2079,8 @@ void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr
auto actor = FetchActor(inputs[0]->DebugString());
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (HasAbstractMonad(inputs[i])) {
continue;
@ -2107,6 +2109,7 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
MS_EXCEPTION_IF_NULL(func_graph);
@ -2123,8 +2126,8 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
continue;
}
gather_actor->device_tensor_store_keys_.push_back(
{i - kCallInputStartPos - persist_input_num, inputs[i].get()});
gather_actor->device_tensor_store_keys_.emplace_back(i - kCallInputStartPos - persist_input_num,
inputs[i].get());
gather_actor->device_contexts_[i - kCallInputStartPos - persist_input_num] =
graph_compiler_info.control_node_parser_->GetFrontValueNodeDeviceContext(inputs[i]);
} else if ((inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]->cast<ParameterPtr>())) ||
@ -2150,7 +2153,9 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
MS_EXCEPTION_IF_NULL(func_graph);
auto actor = FetchActor(func_graph->get_return()->DebugString());
MS_EXCEPTION_IF_NULL(actor);
LinkDataArrowForSwitchActor(graph_compiler_info, dynamic_cast<SwitchActor *>(actor));
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor);
}
// Link arrow for gather actor for call input kernel graph.
@ -2160,6 +2165,7 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
auto actor = FetchActor(kernel_graph->ToString());
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) {
const auto &input_with_index = gather_actor->data_nodes_[i];
@ -2286,8 +2292,10 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c
const auto &actor_name = backend_node.first->fullname_with_scope();
const auto &actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto op_arrow = std::make_shared<DataArrow>(backend_node.second, to_actor->GetAID(), to_index);
auto from_actor = dynamic_cast<KernelActor *>(actor);
MS_EXCEPTION_IF_NULL(from_actor);
auto op_arrow = std::make_shared<DataArrow>(backend_node.second, to_actor->GetAID(), to_index);
from_actor->output_data_arrows_.emplace_back(op_arrow);
auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, backend_node.second, false);
UpdateRefCount(device_tensor.get(), true);
@ -2416,6 +2424,8 @@ void GraphScheduler::LinkControlArrowForGatherActor(std::vector<KernelActorPtr>
actor = FetchActor(func_graph->get_return()->DebugString());
if (actor != nullptr) {
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
kernel_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
switch_actor->input_controls_num_++;
}
@ -2458,6 +2468,7 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr>
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
gather_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
switch_actor->input_controls_num_++;
}
@ -2470,6 +2481,7 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr>
const auto &actor = FetchActor(actor_name);
if (actor != nullptr) {
const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
switch_actor->output_branch_control_arrows_[i].emplace_back(gather_actor->GetAID());
gather_actor->input_controls_num_++;
}
@ -2495,6 +2507,7 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr>
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
size_t branch_index = switch_actor->branch_id_to_index_.size();
if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) {
@ -2517,6 +2530,8 @@ void GraphScheduler::LinkBranchArrowForSwitchActor(const GraphCompilerInfo &grap
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
for (size_t i = 0; i < switch_actor->branch_func_graph_.size(); ++i) {
const auto &func_graph = switch_actor->branch_func_graph_[i];
if (func_graph == nullptr) {
@ -2556,6 +2571,7 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
gather_actor->output_branch_arrows_.emplace_back(gather_actor->switch_aid_);
}
}