diff --git a/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc b/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc index efa0bc6bc61..fb35a021092 100644 --- a/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc @@ -38,7 +38,7 @@ void GatherActor::Init() { auto data = std::make_unique>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_); output_data_.emplace_back(data.get()); - output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(std::move(data)); + output_data_by_output_index_[IntToSize(data_arrow->from_output_index_)].emplace_back(std::move(data)); } } @@ -113,7 +113,7 @@ void GatherActor::SendOutput(OpContext *context) const { // 2.Send output result. for (const auto &result_arrow : output_result_arrows_) { MS_EXCEPTION_IF_NULL(result_arrow); - size_t from_index = result_arrow->from_output_index_; + size_t from_index = IntToSize(result_arrow->from_output_index_); const auto &front_node = data_nodes_[from_index].first; for (const auto &backend_node : front_to_backend_parameter_.at(front_node)) { if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second, false).get() == diff --git a/mindspore/ccsrc/runtime/framework/actor/gather_actor.h b/mindspore/ccsrc/runtime/framework/actor/gather_actor.h index de431eee227..3a0f45de737 100644 --- a/mindspore/ccsrc/runtime/framework/actor/gather_actor.h +++ b/mindspore/ccsrc/runtime/framework/actor/gather_actor.h @@ -85,7 +85,7 @@ class GatherActor : public OpActor { // The device tensors for launch. std::vector input_device_tensors_; // The branch if for current step. - int input_branch_id_; + int input_branch_id_{kInvalidBranchID}; // Input data. std::unordered_map>> input_data_; diff --git a/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc b/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc index 6623444c117..5ff8be84fd5 100644 --- a/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc @@ -208,7 +208,7 @@ void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t b branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin()); return; } - device_tensor_store_keys_.push_back({input_nodes_.size(), node.get()}); + device_tensor_store_keys_.emplace_back(input_nodes_.size(), node.get()); branch_inputs_pos_[branch].push_back(input_nodes_.size()); input_nodes_.push_back(node_with_index); return; @@ -245,7 +245,7 @@ void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) { } else if (IsCallNode(real_input.first)) { std::vector call_nodes; const auto call_output_num = FetchOutputSizebyCallNode(real_input.first, &call_nodes); - if (call_output_num <= 0) { + if (call_output_num == 0) { MS_LOG(EXCEPTION) << "Invalid output num for call input:" << AnfAlgo::GetNodeDebugString(real_input.first); } for (size_t i = 0; i < call_output_num; ++i) { @@ -268,7 +268,7 @@ size_t SwitchActor::GetIndex(OpContext *context) { input_branch_ids_[context->sequential_num_].empty()) { MS_LOG(ERROR) << "Invalid branch id for actor:" + GetAID().Name(); } - size_t branch_id = input_branch_ids_[context->sequential_num_].top(); + auto branch_id = input_branch_ids_[context->sequential_num_].top(); input_branch_ids_[context->sequential_num_].pop(); if (branch_id_to_index_.find(branch_id) == branch_id_to_index_.end()) { MS_LOG(ERROR) << "Invalid branch id for switch actor:" + GetAID().Name() + @@ -278,9 +278,8 @@ size_t SwitchActor::GetIndex(OpContext *context) { } DeviceTensor *device_tensor = input_device_tensors_[0]; - if (device_tensor == nullptr) { - MS_LOG(ERROR) << "Index of switch actor is empty:" + GetAID().Name(); - } + MS_EXCEPTION_IF_NULL(device_tensor); + auto inputs = node_->inputs(); TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0); size_t size = abstract::TypeIdSize(type_id); @@ -291,7 +290,10 @@ size_t SwitchActor::GetIndex(OpContext *context) { int64_t index = 0; char buf[kMaxSwitchCondSize] = {0}; ShapeVector host_shape; - device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast(buf)); + if (!device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast(buf))) { + MS_LOG(ERROR) << GetAID().Name() + " get index from device address failed, type id:" + std::to_string(type_id) + + ", device type:" + std::to_string(static_cast(device_context_->GetDeviceAddressType())); + } if (type_id == TypeId::kNumberTypeInt32) { index = static_cast((static_cast(static_cast(buf)))[0]); @@ -306,7 +308,7 @@ size_t SwitchActor::GetIndex(OpContext *context) { // SwitchLayer node support negative index range [-size, -1]. if (index < 0) { - index += branch_func_graph_.size(); + index += SizeToInt(branch_func_graph_.size()); } return static_cast(index); } @@ -403,7 +405,7 @@ void SwitchActor::SendOutput(OpContext *context) { " total:" + std::to_string(branch_inputs_pos_[index].size()) + " actor:" + GetAID().Name(); SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); } - size_t from_index = branch_inputs_pos_[index][result_arrow->from_output_index_]; + size_t from_index = branch_inputs_pos_[index][IntToSize(result_arrow->from_output_index_)]; MS_LOG(DEBUG) << "Switch actor:" << GetAID() << " send result addr:" << input_device_tensors_[from_index]; bool is_send = false; diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.cc b/mindspore/ccsrc/runtime/framework/control_node_parser.cc index 8e8dd15fe3c..0110bc9fc78 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.cc @@ -556,7 +556,7 @@ bool IsSubCallNode(const AnfNodePtr &node) { std::vector FetchAllRealInputNodeByParameter(const KernelWithIndex &node) { std::vector parameters; - const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second); + const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, SizeToInt(node.second)); const auto &real_node = real_node_with_index.first; if (real_node->isa()) { if (!HasAbstractRef(real_node) && !HasAbstractMonad(real_node)) { @@ -749,7 +749,7 @@ void ControlNodeParser::Parse(const std::vector &control_nodes, cons } } - FetchFrontToBackendParameter(graphs, device_contexts, control_nodes, real_to_formal_front_parameters, + FetchFrontToBackendParameter(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters); FetchFuncGraphToParameter(control_nodes); @@ -924,14 +924,14 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &contr << AnfAlgo::GetNodeDebugString(parameters[i - kCallInputStartPos]) << ", used the default format"; CreateDeviceTensorForFrontParameter(inputs[i], device_contexts[0]); - front_value_nodes_.push_back({inputs[i], device_contexts[0]}); + front_value_nodes_.emplace_back(inputs[i], device_contexts[0]); continue; } const auto &backend_node = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].first; const auto &device_context = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].second; CreateDeviceTensorForValueNode(inputs[i], backend_node, device_context); - front_value_nodes_.push_back({inputs[i], device_context}); + front_value_nodes_.emplace_back(inputs[i], device_context); } } } @@ -940,18 +940,15 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &contr for (size_t index = 0; index < graphs.size(); ++index) { const auto &graph = graphs[index]; MS_EXCEPTION_IF_NULL(graph); - auto execution_order = graph->execution_order(); for (const auto ¶meter : graph->input_nodes()) { - const auto &front_node = graph->GetFrontAnfByBackendAnf(parameter); - const auto &internal_node = graph->GetFrontNodeByInternalParameter(parameter); - MS_EXCEPTION_IF_NULL(parameter); + if (IsInternalParameter(parameter, graph)) { auto front_node_with_index = graph->GetFrontNodeByInternalParameter(parameter); MS_EXCEPTION_IF_NULL(front_node_with_index.first); - const auto &front_output_with_index = - AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false); + const auto &front_output_with_index = AnfAlgo::VisitKernelWithReturnType( + front_node_with_index.first, SizeToInt(front_node_with_index.second), false); auto front_output_node = front_output_with_index.first; MS_EXCEPTION_IF_NULL(front_output_node); if (AnfAlgo::CheckPrimitiveType(front_output_node, prim::kPrimSwitch)) { @@ -959,7 +956,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &contr FetchValueNodeBySwitchNode(front_output_node, &value_nodes); for (const auto value_node : value_nodes) { CreateDeviceTensorForValueNode(value_node, parameter, device_contexts[index]); - front_value_nodes_.push_back({value_node, device_contexts[index]}); + front_value_nodes_.emplace_back(value_node, device_contexts[index]); } } } @@ -975,7 +972,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &contr if (output->isa() && GetFrontValueNodeDeviceContext(output) == nullptr) { const auto &device_context = call_node_to_backend_parameter.second.second; CreateDeviceTensorForValueNode(output, call_node_to_backend_parameter.second.first, device_context); - front_value_nodes_.push_back({output, device_context}); + front_value_nodes_.emplace_back(output, device_context); } } } @@ -1054,7 +1051,7 @@ void ControlNodeParser::FetchFrontToFrontParameter( if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch) || AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) { std::vector call_inputs; - call_inputs.assign(inputs.begin() + kCallInputStartPos, inputs.end()); + call_inputs.assign(inputs.begin() + SizeToInt(kCallInputStartPos), inputs.end()); switch_input_parse(inputs[0], call_inputs); } else if (IsCallNode(inputs[0])) { continue; @@ -1095,7 +1092,7 @@ std::vector ControlNodeParser::FetchControlNodeParameter(const std:: if (backend_iter == front_to_backend_parameters_.end()) { CreateDeviceTensorForFrontParameter(parameter, device_context); front_to_backend_parameters_[parameter] = {parameter, device_context}; - front_parameters_.push_back({parameter, device_context}); + front_parameters_.emplace_back(parameter, device_context); } } @@ -1215,7 +1212,6 @@ std::vector FetchParameterbyKernelGraph(const KernelGraphPtr &g void ControlNodeParser::FetchFrontToBackendParameter(const std::vector &graphs, const std::vector &device_contexts, - const std::vector &control_nodes, const RealToFormalNode &real_to_formal_front_parameters, const RealToFormalNode &formal_to_real_front_parameters) { if (graphs.size() != device_contexts.size()) { @@ -1356,12 +1352,12 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_ std::set *switch_nodes, std::set *results) { if (front_output->isa()) { - (*results).insert({front_output, 0}); + (*results).emplace(front_output, 0); const auto &iter = formal_to_real_parameters_.find(front_output); if (iter != formal_to_real_parameters_.end()) { for (const auto &node : iter->second) { - (*results).insert(node); + (*results).emplace(node); } } } else if (front_output->isa()) { @@ -1370,7 +1366,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_ if (iter != formal_to_real_parameters_.end()) { for (const auto &node : iter->second) { - (*results).insert(node); + (*results).emplace(node); } } else { MS_LOG(EXCEPTION) << "Cannot find backend node for front parameter:" << AnfAlgo::GetNodeDebugString(front_output); @@ -1402,7 +1398,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_ const auto iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(front_output, 0)); if (iter != front_to_backend_kernels_.end()) { - (*results).insert(iter->second.first); + (*results).emplace(iter->second.first); } else { MS_LOG(EXCEPTION) << "Cannot find backend node for front kernel:" << AnfAlgo::GetNodeDebugString(front_output); } @@ -1429,7 +1425,7 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode( MS_LOG(WARNING) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first); continue; } - formal_to_real_parameters_[formal_parameter].push_back({iter->second.first, 0}); + formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first, 0); } else { const auto iter = front_to_backend_kernels_.find(node_with_index); if (iter == front_to_backend_kernels_.end()) { @@ -1439,7 +1435,7 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode( } } } else if (real_parameter->isa()) { - formal_to_real_parameters_[formal_parameter].push_back({real_parameter, 0}); + formal_to_real_parameters_[formal_parameter].emplace_back(real_parameter, 0); } else if (IsCallNode(real_parameter)) { const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter); for (const auto func_graph : func_graphs) { @@ -1512,7 +1508,7 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector auto front_node = graph->GetFrontAnfByBackendAnf(value_node); if (front_node != nullptr) { - formal_to_real_parameters_[front_node].push_back({value_node, 0}); + formal_to_real_parameters_[front_node].emplace_back(value_node, 0); } } } @@ -1521,7 +1517,7 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector for (const auto &front_weight : host_parameter_to_weight.second) { const auto &iter = front_to_backend_parameters_.find(host_parameter_to_weight.first); if (iter != front_to_backend_parameters_.end()) { - formal_to_real_parameters_[front_weight].push_back({iter->second.first, 0}); + formal_to_real_parameters_[front_weight].emplace_back(iter->second.first, 0); } } } @@ -1552,10 +1548,10 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector } } for (const auto parameter_pair : front_to_backend_parameters) { - formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0}); + formal_to_real_parameters_[parameter_pair.first].emplace_back(parameter_pair.second.first, 0); } for (const auto parameter_pair : front_to_backend_parameters_) { - formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0}); + formal_to_real_parameters_[parameter_pair.first].emplace_back(parameter_pair.second.first, 0); } } diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.h b/mindspore/ccsrc/runtime/framework/control_node_parser.h index 01ef69e501f..51d06a3f29d 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.h +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.h @@ -150,7 +150,6 @@ class ControlNodeParser { // 2. The parameter from control nodes. void FetchFrontToBackendParameter(const std::vector &graphs, const std::vector &device_contexts, - const std::vector &control_nodes, const RealToFormalNode &real_to_formal_front_parameters, const RealToFormalNode &formal_to_real_front_parameters); // Get the relationship between the front and backend of the executable kernel in all kernel graphs. diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc index c7106511a57..240f7008f45 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc @@ -1262,7 +1262,7 @@ std::vector GraphScheduler::BuildGatherActor(const GraphCompiler if (HasAbstractMonad(parameter) || HasAbstractRef(parameter)) { continue; } - parameters.push_back({parameter, 0}); + parameters.emplace_back(parameter, 0); } const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph); @@ -1291,7 +1291,7 @@ std::vector GraphScheduler::BuildGatherActor(const GraphCompiler if (HasAbstractMonad(inputs[i]) || (inputs[i]->isa() && HasAbstractRef(inputs[i]))) { continue; } - parameters.push_back({inputs[i], 0}); + parameters.emplace_back(inputs[i], 0); } auto func_graph = control_node->func_graph(); @@ -1335,7 +1335,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf if (from_kernel->isa() && graph_compiler_info.control_node_parser_->IsCallInputKernelGraph(graph)) { const auto &kernel_with_index = GetFrontNodeByKernelGraph(from_kernel, graph); const auto &real_front_node_with_index = - AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, kernel_with_index.second); + AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, SizeToInt(kernel_with_index.second)); if (HasAbstractRef(real_front_node_with_index.first)) { to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, real_front_node_with_index.first.get()); @@ -1966,6 +1966,7 @@ void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo auto actor = FetchActor(actor_name); MS_EXCEPTION_IF_NULL(actor); auto switch_actor = dynamic_cast(actor); + MS_EXCEPTION_IF_NULL(switch_actor); // Set branch index into switch actor. size_t branch_index = switch_actor->branch_id_to_index_.size(); @@ -2071,7 +2072,6 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector &control_nodes) { for (const auto &node : control_nodes) { CNodePtr cnode = node->cast(); - const auto &from_func_graph = node->func_graph(); auto inputs = cnode->inputs(); // Before link data arrow, parameters of the call node in switch-call need to be add to the switch actor. @@ -2079,6 +2079,8 @@ void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vectorDebugString()); MS_EXCEPTION_IF_NULL(actor); auto switch_actor = dynamic_cast(actor); + MS_EXCEPTION_IF_NULL(switch_actor); + for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { if (HasAbstractMonad(inputs[i])) { continue; @@ -2107,6 +2109,7 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi auto actor = FetchActor(actor_name); MS_EXCEPTION_IF_NULL(actor); auto gather_actor = dynamic_cast(actor); + MS_EXCEPTION_IF_NULL(gather_actor); const auto &func_graph = GetValueNode(inputs[0]); MS_EXCEPTION_IF_NULL(func_graph); @@ -2123,8 +2126,8 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi continue; } - gather_actor->device_tensor_store_keys_.push_back( - {i - kCallInputStartPos - persist_input_num, inputs[i].get()}); + gather_actor->device_tensor_store_keys_.emplace_back(i - kCallInputStartPos - persist_input_num, + inputs[i].get()); gather_actor->device_contexts_[i - kCallInputStartPos - persist_input_num] = graph_compiler_info.control_node_parser_->GetFrontValueNodeDeviceContext(inputs[i]); } else if ((inputs[i]->isa() && HasAbstractRef(inputs[i]->cast())) || @@ -2150,7 +2153,9 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi MS_EXCEPTION_IF_NULL(func_graph); auto actor = FetchActor(func_graph->get_return()->DebugString()); MS_EXCEPTION_IF_NULL(actor); - LinkDataArrowForSwitchActor(graph_compiler_info, dynamic_cast(actor)); + auto switch_actor = dynamic_cast(actor); + MS_EXCEPTION_IF_NULL(switch_actor); + LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor); } // Link arrow for gather actor for call input kernel graph. @@ -2160,6 +2165,7 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi auto actor = FetchActor(kernel_graph->ToString()); MS_EXCEPTION_IF_NULL(actor); auto gather_actor = dynamic_cast(actor); + MS_EXCEPTION_IF_NULL(gather_actor); for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) { const auto &input_with_index = gather_actor->data_nodes_[i]; @@ -2286,8 +2292,10 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c const auto &actor_name = backend_node.first->fullname_with_scope(); const auto &actor = FetchActor(actor_name); MS_EXCEPTION_IF_NULL(actor); - auto op_arrow = std::make_shared(backend_node.second, to_actor->GetAID(), to_index); auto from_actor = dynamic_cast(actor); + MS_EXCEPTION_IF_NULL(from_actor); + + auto op_arrow = std::make_shared(backend_node.second, to_actor->GetAID(), to_index); from_actor->output_data_arrows_.emplace_back(op_arrow); auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, backend_node.second, false); UpdateRefCount(device_tensor.get(), true); @@ -2416,6 +2424,8 @@ void GraphScheduler::LinkControlArrowForGatherActor(std::vector actor = FetchActor(func_graph->get_return()->DebugString()); if (actor != nullptr) { auto switch_actor = dynamic_cast(actor); + MS_EXCEPTION_IF_NULL(switch_actor); + kernel_actor->output_control_arrows_.emplace_back(switch_actor->GetAID()); switch_actor->input_controls_num_++; } @@ -2458,6 +2468,7 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector auto actor = FetchActor(actor_name); MS_EXCEPTION_IF_NULL(actor); auto gather_actor = dynamic_cast(actor); + MS_EXCEPTION_IF_NULL(gather_actor); gather_actor->output_control_arrows_.emplace_back(switch_actor->GetAID()); switch_actor->input_controls_num_++; } @@ -2470,6 +2481,7 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector const auto &actor = FetchActor(actor_name); if (actor != nullptr) { const auto &gather_actor = dynamic_cast(actor); + MS_EXCEPTION_IF_NULL(gather_actor); switch_actor->output_branch_control_arrows_[i].emplace_back(gather_actor->GetAID()); gather_actor->input_controls_num_++; } @@ -2495,6 +2507,7 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector auto actor = FetchActor(actor_name); MS_EXCEPTION_IF_NULL(actor); auto switch_actor = dynamic_cast(actor); + MS_EXCEPTION_IF_NULL(switch_actor); size_t branch_index = switch_actor->branch_id_to_index_.size(); if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) { @@ -2517,6 +2530,8 @@ void GraphScheduler::LinkBranchArrowForSwitchActor(const GraphCompilerInfo &grap auto actor = FetchActor(actor_name); MS_EXCEPTION_IF_NULL(actor); auto switch_actor = dynamic_cast(actor); + MS_EXCEPTION_IF_NULL(switch_actor); + for (size_t i = 0; i < switch_actor->branch_func_graph_.size(); ++i) { const auto &func_graph = switch_actor->branch_func_graph_[i]; if (func_graph == nullptr) { @@ -2556,6 +2571,7 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap auto actor = FetchActor(actor_name); MS_EXCEPTION_IF_NULL(actor); auto gather_actor = dynamic_cast(actor); + MS_EXCEPTION_IF_NULL(gather_actor); gather_actor->output_branch_arrows_.emplace_back(gather_actor->switch_aid_); } }