!28710 Add stack actor for switch actor.

Merge pull request !28710 from gaoyong10/runtime_second12
This commit is contained in:
i-robot 2022-01-10 06:37:46 +00:00 committed by Gitee
commit 9281888ee9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 246 additions and 72 deletions

View File

@ -21,6 +21,9 @@
namespace mindspore {
namespace runtime {
void AbstractActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(input_data);
MS_EXCEPTION_IF_NULL(input_data->data_);
MS_EXCEPTION_IF_NULL(input_data->data_->GetPtr());
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
(void)input_op_datas_[sequential_num].emplace_back(input_data);

View File

@ -37,9 +37,11 @@ void StackActor::Init() {
// 6. Call input partial.
input_datas_num_ = formal_parameters_.size() - input_stack_data_num_ - input_stack_partials_num_;
if (input_stack_data_num_ < device_tensor_store_keys_.size() + local_device_tensors_.size()) {
MS_LOG(EXCEPTION) << "Invalid input parameter data num:" << input_stack_data_num_
<< " device store num:" << device_tensor_store_keys_.size() << " local device tensor num"
<< local_device_tensors_.size() << " for actor:" << GetAID();
MS_LOG(EXCEPTION) << "Invalid input stack data num:" << input_stack_data_num_
<< " device store num:" << device_tensor_store_keys_.size()
<< " local device tensor num:" << local_device_tensors_.size()
<< " input stack data num:" << input_stack_data_num_
<< " input stack partial num:" << input_stack_partials_num_ << " for actor:" << GetAID();
}
// Fetch the total number of input partial.
@ -63,8 +65,8 @@ void StackActor::Init() {
if (input_stack_data_num_ + input_stack_partials_num_ + input_datas_num_ + input_partials_num_ +
device_tensor_store_keys_.size() + local_device_tensors_.size() !=
formal_parameters_.size()) {
MS_LOG(EXCEPTION) << "Invalid input num, input parameter data num:" << input_stack_data_num_
<< " input parameter partial num:" << input_stack_partials_num_
MS_LOG(EXCEPTION) << "Invalid input num, input stack data num:" << input_stack_data_num_
<< " input stack partial num:" << input_stack_partials_num_
<< " input data num:" << input_datas_num_ << " input partial num:" << input_partials_num_
<< " device tensor store size:" << device_tensor_store_keys_.size()
<< " need total size:" << formal_parameters_.size() << " for actor:" << GetAID();

View File

@ -39,8 +39,8 @@ void SwitchActor::Init() {
}
auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
MS_EXCEPTION_IF_NULL(data);
(void)output_data_.emplace_back(std::move(data));
(void)output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(data.get());
(void)output_data_.emplace_back(std::move(data));
}
}

View File

@ -689,6 +689,56 @@ void AddFormalToRealParameter(const AnfNodePtr &formal_parameter, const AnfNodeP
(*formal_to_real_parameters)[{formal_parameter, i}].insert(real_parameters.begin(), real_parameters.end());
}
}
// Recursively traverse the input to confirm whether there is an input of recursive call.
bool IsFirstControlNode(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes,
std::set<AnfNodePtr> unrecursion_call_nodes) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(checked_nodes);
if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) {
return true;
}
checked_nodes->emplace(node);
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
for (const auto &input : inputs) {
MS_EXCEPTION_IF_NULL(input);
if ((AnfAlgo::IsCallNode(input) && unrecursion_call_nodes.find(input) == unrecursion_call_nodes.end()) ||
(!IsFirstControlNode(input, checked_nodes, unrecursion_call_nodes))) {
return false;
}
}
return true;
}
// Get the level of the control node, recursively traverse all the inputs of the node, and find the largest level
// among them.
size_t ParseControlNodeLevel(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes,
const mindspore::HashMap<AnfNodePtr, size_t> &node_to_level) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(checked_nodes);
if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) {
return 0;
}
checked_nodes->emplace(node);
auto iter = node_to_level.find(node);
if (iter != node_to_level.end()) {
return iter->second;
}
size_t level = 0;
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
for (const auto &input : inputs) {
size_t tmp_level = ParseControlNodeLevel(input, checked_nodes, node_to_level);
level = (tmp_level > level ? tmp_level : level);
}
return level;
}
} // namespace
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index) {
@ -870,6 +920,8 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
ParseNeedStackKernelGraph(kernel_graph_to_device_contexts);
ParseNodeLevel(control_nodes);
ParseNeedStackControlNode(control_nodes);
ParseFormalToRealParameter(control_nodes);
@ -961,7 +1013,7 @@ bool ControlNodeParser::IsRecursionKernelGraph(const KernelGraphPtr &graph) {
MS_LOG(EXCEPTION) << "Invalid kernel graph:" << graph->ToString();
}
MS_EXCEPTION_IF_NULL(group_info_iter->second);
if (!group_info_iter->second->is_call_input_) {
if (!group_info_iter->second->need_stack_) {
return false;
}
for (const auto &front_input_node : group_info_iter->second->front_input_nodes_) {
@ -1303,7 +1355,7 @@ bool ControlNodeParser::IsCallInputKernelGraph(KernelGraph *const graph) {
bool ControlNodeParser::IsCallInputKernelGraphGroup(const std::string &group_name) {
for (const auto &graph_group : kernel_graph_group_infos_) {
if (group_name.find(graph_group->group_name_) != std ::string::npos) {
return graph_group->is_call_input_;
return graph_group->need_stack_;
}
}
MS_LOG(EXCEPTION) << "Invalid kernel graph group name:" << group_name;
@ -1697,28 +1749,6 @@ AnfNodePtr ControlNodeParser::FetchRootGraphFrontNodeBySubFrontNode(const AnfNod
return sub_front_node_to_root_front_node_[sub_front_node];
}
bool IsFirstControlNode(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes,
std::set<AnfNodePtr> unrecursion_call_nodes) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(checked_nodes);
if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) {
return true;
}
checked_nodes->emplace(node);
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
for (const auto &input : inputs) {
MS_EXCEPTION_IF_NULL(input);
if ((AnfAlgo::IsCallNode(input) && unrecursion_call_nodes.find(input) == unrecursion_call_nodes.end()) ||
(!IsFirstControlNode(input, checked_nodes, unrecursion_call_nodes))) {
return false;
}
}
return true;
}
void ControlNodeParser::ParseFirstControlNodeForFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
for (const auto &control_node : control_nodes) {
std::set<AnfNodePtr> checked_nodes;
@ -1806,10 +1836,10 @@ void ControlNodeParser::ParseNeedStackControlNode(const std::vector<AnfNodePtr>
if (call_input_num != 0 && (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimDepend))) {
need_stack_control_nodes_.emplace(control_node);
}
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) {
auto input_with_indexs = FetchInputNodeByCNode(control_node);
if (std::any_of(input_with_indexs.begin(), input_with_indexs.end(),
[this](const auto &input_with_index) { return IsRecursionCallNode(input_with_index.first); })) {
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) ||
AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
if (!IsInputInSameLevel(control_node)) {
need_stack_control_nodes_.emplace(control_node);
MS_LOG(DEBUG) << "Add need stack control node:" << control_node->DebugString();
}
@ -1850,7 +1880,7 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte
continue;
}
if (AnfAlgo::IsCallNode(front_node_with_index.first)) {
kernel_graph_group_info->is_call_input_ = true;
kernel_graph_group_info->need_stack_ = true;
}
if (AnfAlgo::CheckPrimitiveType(front_node_with_index.first, prim::kPrimTupleGetItem)) {
@ -1877,7 +1907,7 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte
}
kernel_graphs_to_group_info_[kernel_graph] = kernel_graph_group_info;
if (kernel_graph_group_info->is_call_input_) {
if (kernel_graph_group_info->need_stack_) {
call_input_kernel_graphs_.emplace(kernel_graph.get());
}
}
@ -1890,6 +1920,97 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte
}
}
void ControlNodeParser::ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes) {
size_t level = 0;
// 1. Parse levels of control nodes.
for (const auto &control_node : control_nodes) {
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
node_to_level_[control_node] = level;
level = 0;
const auto &func_graph = control_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
const auto &parameters = func_graph->parameters();
for (const auto &parameter : parameters) {
node_to_level_[parameter] = level;
}
continue;
} else if (AnfAlgo::IsCallNode(control_node) && IsRecursionCallNode(control_node)) {
++level;
node_to_level_[control_node] = level;
} else {
std::set<AnfNodePtr> checked_nodes;
node_to_level_[control_node] = ParseControlNodeLevel(control_node, &checked_nodes, node_to_level_);
}
}
// 2. Parse the levels of kernel graph outputs.
for (const auto &kernel_graph_group_info : kernel_graph_group_infos_) {
level = 0;
for (const auto &front_input_node : kernel_graph_group_info->front_input_nodes_) {
const auto &input_node = front_input_node.first.first;
auto iter = node_to_level_.find(input_node);
if (iter != node_to_level_.end() && level < iter->second) {
level = iter->second;
}
}
for (const auto &front_output_node : kernel_graph_group_info->front_output_nodes_) {
const auto &output_node = front_output_node.first.first;
node_to_level_[output_node] = level;
}
}
// Parse the levels of kernel graph groups.
for (const auto &kernel_graph_group_info : kernel_graph_group_infos_) {
size_t max_level = 0;
for (const auto &front_input_node : kernel_graph_group_info->front_input_nodes_) {
const auto &input_node = front_input_node.first.first;
auto iter = node_to_level_.find(input_node);
if (iter == node_to_level_.end()) {
MS_LOG(EXCEPTION) << "Failed to get level by input node:" << input_node->DebugString()
<< " for kernel graph:" << kernel_graph_group_info->group_name_;
}
max_level = (max_level > iter->second ? max_level : iter->second);
}
if (max_level > 0) {
kernel_graph_group_info->need_stack_ = true;
kernel_graph_group_info->level_ = max_level;
for (const auto &kernel_graph : kernel_graph_group_info->graphs_) {
call_input_kernel_graphs_.emplace(kernel_graph.get());
}
}
}
}
bool ControlNodeParser::IsInputInSameLevel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return true;
}
auto input_with_indexes = FetchInputNodeByCNode(node);
size_t level = SIZE_MAX;
for (const auto &input_with_index : input_with_indexes) {
auto input_node = input_with_index.first;
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<ValueNode>()) {
continue;
}
auto iter = node_to_level_.find(input_node);
if (iter == node_to_level_.end()) {
MS_LOG(EXCEPTION) << "Failed to find level by input:" << input_node->DebugString()
<< " for node:" << node->DebugString();
}
if (level == SIZE_MAX) {
level = iter->second;
continue;
}
if (level != iter->second) {
return false;
}
}
return true;
}
void ControlNodeParser::CreateDeviceTensorForRootGraphParameter(DeviceContext *default_context) {
MS_EXCEPTION_IF_NULL(default_context);
for (const auto &parameter : root_graph_parameters_) {

View File

@ -82,8 +82,11 @@ using CallNodeToFuncGraph = mindspore::HashMap<AnfNodePtr, std::set<FuncGraphPtr
using KernelGraphToDeviceContext = mindspore::HashMap<KernelGraphPtr, DeviceContext *>;
// In the control flow, heterogeneous kernel graphs need to be reconnected in the same group, and the kernel graph
// group info is used to store the inputs and outputs of the group.
// Need stack indicates whether a stack actor needs to be created for the group.
// Level indicates the level of the output of the graph in the group.
struct KernelGraphGroupInfo {
bool is_call_input_;
bool need_stack_{0};
size_t level_;
std::string group_name_;
std::set<KernelGraphPtr> graphs_;
std::map<KernelWithIndex, const DeviceContext *> front_input_nodes_;
@ -128,6 +131,7 @@ class ControlNodeParser {
// If there is a recursive call node in the input of the kernel graph, the graph is recursive.
bool IsRecursionKernelGraph(const KernelGraphPtr &graph);
bool IsSameKernelGraphGroup(const AnfNodePtr &node, const KernelGraphPtr &graph);
bool IsInputInSameLevel(const AnfNodePtr &node);
const std::vector<AnfNodePtr> &control_node_parameters() const { return control_node_parameters_; }
const FrontToBackendNodeWithContext &front_to_backend_parameters() const { return front_to_backend_parameters_; }
@ -217,6 +221,8 @@ class ControlNodeParser {
// When a control node or kernel graph has input that is a call node, you need to add a stack actor for it.
void ParseNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes);
void ParseNeedStackKernelGraph(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts);
// Parse the level of inputs and outputs of graphs and all control nodes.
void ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes);
// When the parameter is directly used as the condition of the switch, there will be no back-end node, and a device
// tensor needs to be created for it.
void CreateDeviceTensorForRootGraphParameter(DeviceContext *default_context);
@ -241,6 +247,15 @@ class ControlNodeParser {
// id needs to be sent to the gather actor corresponding to the funcgraph, and the gather will send the branch id
// to its output switch actor.
mindspore::HashMap<AnfNodePtr, int> call_node_to_branch_id_;
// Level indicates that the input of the node depends on the number of the recursive call node in the funcgraph.
// During graph scheduler, the input needs to be graded according to the input's dependence on the recursive call
// node, and according to this level, the lower-level inputs are pushed in the stack actor. When arranging, first
// sort the call nodes in the funcgraph according to their topological relationships, and then confirm the
// dependencies of other nodes on these call nodes in turn.
// For example, the dependencies are a -> b, b -> d, c -> d, where b is a call node, then the level of a and c is 0,
// and the level of bd is 1, then since d has inputs with different levels of b and c, it is necessary to add a
// stack to d.
mindspore::HashMap<AnfNodePtr, size_t> node_to_level_;
CallNodeToFuncGraph call_node_to_func_graphs_;
// host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph.
// When initializing the weights, all related weights need to be recorded as the same device tensor.

View File

@ -81,16 +81,6 @@ void FetchRealDependNodeByAutoMonad(const AnfNodePtr &node, std::set<AnfNodePtr>
}
}
bool IsControlFlowArrow(const ControlNodeParserPtr &parser, const KernelGraphPtr &graph, const AnfNodePtr &from_node) {
MS_EXCEPTION_IF_NULL(parser);
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(from_node);
bool is_call_input_kernl_graph = parser->IsCallInputKernelGraph(graph.get());
return ((!is_call_input_kernl_graph) && ((from_node == nullptr) || (!from_node->isa<Parameter>()))) ||
(from_node != nullptr && IsPersistentDeviceTensor(from_node)) ||
(from_node != nullptr && parser->IsSameKernelGraphGroup(from_node, graph));
}
// Parameter and ref node can not copy the device tensor.
bool is_need_copy_device_tensor(const AnfNodePtr &backend_node, size_t index) {
MS_EXCEPTION_IF_NULL(backend_node);
@ -317,10 +307,10 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
// Create a corresponding stack actor for each kernel graph that has a call node as input.
for (const auto &kernel_graph_group_info : parser->kernel_graph_group_infos_) {
if (!kernel_graph_group_info->is_call_input_) {
if (!kernel_graph_group_info->need_stack_) {
continue;
}
const auto &actor_name = kernel_graph_group_info->group_name_ + kStackActorNameSuffix;
size_t input_parameter_data_num = 0;
std::vector<const DeviceContext *> device_contexts;
std::vector<KernelWithIndex> formal_parameters;
@ -329,8 +319,13 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
// If the input comes from inside funcgraph, put it at the front of the vector, otherwise put it at the end.
const auto &from_node = node_with_context.first.first;
MS_EXCEPTION_IF_NULL(from_node);
const auto &graph = (from_node->isa<CNode>() ? parser->FetchKernelGraphByFrontNode(from_node) : nullptr);
if (parser->IsRecursionCallNode(from_node) || (graph != nullptr && parser->IsRecursionKernelGraph(graph))) {
auto iter = parser->node_to_level_.find(from_node);
if (iter == parser->node_to_level_.end()) {
MS_LOG(EXCEPTION) << "Failed to get level by from node:" << from_node->DebugString()
<< " in graph:" << kernel_graph_group_info->group_name_;
}
if (iter->second == kernel_graph_group_info->level_ &&
(!(parser->IsRootGraphParameter(from_node) && IsPersistentDeviceTensor(from_node)))) {
formal_parameters.emplace_back(node_with_context.first);
device_contexts.emplace_back(node_with_context.second);
} else {
@ -339,7 +334,6 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
input_parameter_data_num++;
}
}
const auto &actor_name = kernel_graph_group_info->group_name_ + kStackActorNameSuffix;
const auto &stack_actor = std::make_shared<StackActor>(actor_name, memory_manager_aid_, formal_parameters);
stack_actors.emplace_back(stack_actor);
stack_actor->device_contexts_.swap(device_contexts);
@ -360,6 +354,7 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo
for (const auto &need_stack_control_node : parser->need_stack_control_nodes_) {
MS_EXCEPTION_IF_NULL(need_stack_control_node);
const auto &stack_actor_name = GetActorName(need_stack_control_node) + kStackActorNameSuffix;
std::vector<KernelWithIndex> formal_parameters;
std::vector<const DeviceContext *> device_contexts;
size_t input_parameter_data_num = 0;
@ -372,12 +367,20 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo
MS_EXCEPTION_IF_NULL(func_graph);
control_actor_name = func_graph->ToString() + kExitActorNameSuffix;
} else if (AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimPartial) ||
AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimSwitchLayer) ||
AnfAlgo::IsCallNode(need_stack_control_node)) {
control_actor_name = GetActorName(need_stack_control_node);
} else {
MS_LOG(EXCEPTION) << "Invalid control node:" << need_stack_control_node->DebugString();
}
auto iter = parser->node_to_level_.find(need_stack_control_node);
if (iter == parser->node_to_level_.end()) {
MS_LOG(EXCEPTION) << "Failed to get level for need stack control node:" << need_stack_control_node->DebugString();
}
size_t control_node_level = iter->second;
auto actor = FetchActor(control_actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto control_actor = dynamic_cast<ControlActor *>(actor);
@ -392,13 +395,20 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo
for (size_t i = 0; i < control_actor->formal_parameters_.size(); ++i) {
const auto &parameter = control_actor->formal_parameters_[i];
auto device_context = control_actor->device_contexts_[i];
const auto &graph =
(parameter.first->isa<CNode>() ? parser->FetchKernelGraphByFrontNode(parameter.first) : nullptr);
if (parser->IsRecursionCallNode(parameter.first) || (graph != nullptr && parser->IsRecursionKernelGraph(graph))) {
if (parameter.first->isa<ValueNode>()) {
continue;
}
iter = parser->node_to_level_.find(parameter.first);
if (iter == parser->node_to_level_.end()) {
MS_LOG(EXCEPTION) << "Failed to get level for formal parameter:" << parameter.first->DebugString()
<< " for need stack control node:" << need_stack_control_node->DebugString();
}
if (control_node_level == iter->second &&
(!(parser->IsRootGraphParameter(parameter.first) && IsPersistentDeviceTensor(parameter.first)))) {
formal_parameters.emplace_back(parameter);
device_contexts.emplace_back(device_context);
} else if (parameter.first->isa<ValueNode>()) {
continue;
} else {
formal_parameters.insert(formal_parameters.begin(), parameter);
device_contexts.insert(device_contexts.begin(), device_context);
@ -415,7 +425,6 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo
}
}
// Create stack actor.
const auto &stack_actor_name = GetActorName(need_stack_control_node) + kStackActorNameSuffix;
const auto &stack_actor = std::make_shared<StackActor>(stack_actor_name, memory_manager_aid_, formal_parameters);
stack_actor->device_contexts_ = device_contexts;
stack_actor->input_stack_data_num_ = input_parameter_data_num;
@ -495,9 +504,20 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
const auto &parser = graph_compiler_info.control_node_parser_;
for (auto &switch_actor : control_actor_set->switch_actors_) {
for (size_t i = 0; i < switch_actor->formal_parameters_.size(); ++i) {
LinkArrowbyFormalParameter(switch_actor.get(), switch_actor->formal_parameters_[i], {switch_actor->node_, i},
parser);
MS_EXCEPTION_IF_NULL(switch_actor);
if (parser->need_stack_control_nodes_.find(switch_actor->node_) == parser->need_stack_control_nodes_.end()) {
for (size_t i = 0; i < switch_actor->formal_parameters_.size(); ++i) {
LinkArrowbyFormalParameter(switch_actor.get(), switch_actor->formal_parameters_[i], {switch_actor->node_, i},
parser);
}
} else {
// If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
auto stack_actor_name = GetActorName(switch_actor->node_) + kStackActorNameSuffix;
auto actor = FetchActor(stack_actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto stack_actor = dynamic_cast<StackActor *>(actor);
MS_EXCEPTION_IF_NULL(stack_actor);
LinkArrowFromStackActor(stack_actor, switch_actor.get());
}
}
@ -601,7 +621,14 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
MS_EXCEPTION_IF_NULL(actor);
const auto &switch_actor = dynamic_cast<SwitchActor *>(actor);
MS_EXCEPTION_IF_NULL(switch_actor);
LinkPartialArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second);
const auto &abstract = from_node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
if (abstract->isa<abstract::AbstractFunction>()) {
LinkPartialArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second);
} else {
LinkDataArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second);
}
} else if (AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) {
// Link arrow from gather actor
const auto &actor_name = GetActorName(from_node);
@ -934,7 +961,7 @@ void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_
continue;
}
MS_EXCEPTION_IF_NULL(kernel_graph);
auto actor_name = kernel_graph->ToString() + kStackActorNameSuffix;
auto actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph) + kStackActorNameSuffix;
if (!parser->IsCallInputKernelGraph(kernel_graph.get())) {
const auto &func_graph = parser->FetchFuncGraphByKernelGraph(kernel_graph.get());
MS_EXCEPTION_IF_NULL(func_graph);
@ -1135,10 +1162,8 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap
auto input_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
auto input = input_with_index.first;
MS_EXCEPTION_IF_NULL(input);
if (sink_input_node_linked.count(input) > 0) {
continue;
}
if ((!input->isa<Parameter>()) || HasAbstractMonad(input) || IsPersistentDeviceTensor(input)) {
if (sink_input_node_linked.count(input) > 0 || HasAbstractMonad(input) || parser == nullptr ||
(!parser->IsControlFlowDataArrow(graph, input))) {
continue;
}
auto front_node = graph->GetFrontAnfByBackendAnf(input);
@ -1159,16 +1184,23 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap
// If the formal parameter is a tuple type, the parameter of the kernel graph will not directly correspond
// to the front parameter, but the node in the internal parameter.
const auto &from_node = from_node_with_index.first;
if (IsControlFlowArrow(parser, graph, from_node)) {
continue;
}
// Fetch actor and link.
auto type = FetchKernelTransformType(kernel, graph, {});
auto to_actor = FetchActor(type, "", kernel, graph);
MS_EXCEPTION_IF_NULL(to_actor);
size_t from_index = 0;
if (AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitchLayer)) {
const auto &actor_name = GetActorName(from_node);
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
from_actor = dynamic_cast<ControlActor *>(actor);
} else {
from_index = from_actor->FetchNodePosition(from_node_with_index);
}
MS_EXCEPTION_IF_NULL(from_actor);
auto from_index = from_actor->FetchNodePosition(from_node_with_index);
auto to_index = i;
if (type == KernelTransformType::kSuperKernelActor) {
auto super_kernel_actor = dynamic_cast<SuperKernelActor *>(to_actor);

View File

@ -558,7 +558,8 @@ void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_mu
MS_EXCEPTION_IF_NULL(cut_node);
MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->DebugString();
control_nodes_.push_back(cut_node);
if (AnfAlgo::IsCallNode(cut_node)) {
if (AnfAlgo::IsCallNode(cut_node) || AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitchLayer)) {
const auto &func_graph = cut_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
func_graph_to_kernel_graph_ids_[func_graph].emplace_back(std::vector<GraphId>());