Add stack actor for switch actor.
This commit is contained in:
parent
17c942e4d6
commit
35fbe73bdc
|
@ -21,6 +21,9 @@
|
|||
namespace mindspore {
|
||||
namespace runtime {
|
||||
void AbstractActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(input_data);
|
||||
MS_EXCEPTION_IF_NULL(input_data->data_);
|
||||
MS_EXCEPTION_IF_NULL(input_data->data_->GetPtr());
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto &sequential_num = context->sequential_num_;
|
||||
(void)input_op_datas_[sequential_num].emplace_back(input_data);
|
||||
|
|
|
@ -37,9 +37,11 @@ void StackActor::Init() {
|
|||
// 6. Call input partial.
|
||||
input_datas_num_ = formal_parameters_.size() - input_stack_data_num_ - input_stack_partials_num_;
|
||||
if (input_stack_data_num_ < device_tensor_store_keys_.size() + local_device_tensors_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid input parameter data num:" << input_stack_data_num_
|
||||
<< " device store num:" << device_tensor_store_keys_.size() << " local device tensor num"
|
||||
<< local_device_tensors_.size() << " for actor:" << GetAID();
|
||||
MS_LOG(EXCEPTION) << "Invalid input stack data num:" << input_stack_data_num_
|
||||
<< " device store num:" << device_tensor_store_keys_.size()
|
||||
<< " local device tensor num:" << local_device_tensors_.size()
|
||||
<< " input stack data num:" << input_stack_data_num_
|
||||
<< " input stack partial num:" << input_stack_partials_num_ << " for actor:" << GetAID();
|
||||
}
|
||||
|
||||
// Fetch the total number of input partial.
|
||||
|
@ -63,8 +65,8 @@ void StackActor::Init() {
|
|||
if (input_stack_data_num_ + input_stack_partials_num_ + input_datas_num_ + input_partials_num_ +
|
||||
device_tensor_store_keys_.size() + local_device_tensors_.size() !=
|
||||
formal_parameters_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid input num, input parameter data num:" << input_stack_data_num_
|
||||
<< " input parameter partial num:" << input_stack_partials_num_
|
||||
MS_LOG(EXCEPTION) << "Invalid input num, input stack data num:" << input_stack_data_num_
|
||||
<< " input stack partial num:" << input_stack_partials_num_
|
||||
<< " input data num:" << input_datas_num_ << " input partial num:" << input_partials_num_
|
||||
<< " device tensor store size:" << device_tensor_store_keys_.size()
|
||||
<< " need total size:" << formal_parameters_.size() << " for actor:" << GetAID();
|
||||
|
|
|
@ -39,8 +39,8 @@ void SwitchActor::Init() {
|
|||
}
|
||||
auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
(void)output_data_.emplace_back(std::move(data));
|
||||
(void)output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(data.get());
|
||||
(void)output_data_.emplace_back(std::move(data));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -689,6 +689,56 @@ void AddFormalToRealParameter(const AnfNodePtr &formal_parameter, const AnfNodeP
|
|||
(*formal_to_real_parameters)[{formal_parameter, i}].insert(real_parameters.begin(), real_parameters.end());
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively traverse the input to confirm whether there is an input of recursive call.
|
||||
bool IsFirstControlNode(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes,
|
||||
std::set<AnfNodePtr> unrecursion_call_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(checked_nodes);
|
||||
if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) {
|
||||
return true;
|
||||
}
|
||||
checked_nodes->emplace(node);
|
||||
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &inputs = cnode->inputs();
|
||||
for (const auto &input : inputs) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
if ((AnfAlgo::IsCallNode(input) && unrecursion_call_nodes.find(input) == unrecursion_call_nodes.end()) ||
|
||||
(!IsFirstControlNode(input, checked_nodes, unrecursion_call_nodes))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Get the level of the control node, recursively traverse all the inputs of the node, and find the largest level
|
||||
// among them.
|
||||
size_t ParseControlNodeLevel(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes,
|
||||
const mindspore::HashMap<AnfNodePtr, size_t> &node_to_level) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(checked_nodes);
|
||||
if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) {
|
||||
return 0;
|
||||
}
|
||||
checked_nodes->emplace(node);
|
||||
|
||||
auto iter = node_to_level.find(node);
|
||||
if (iter != node_to_level.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
size_t level = 0;
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &inputs = cnode->inputs();
|
||||
for (const auto &input : inputs) {
|
||||
size_t tmp_level = ParseControlNodeLevel(input, checked_nodes, node_to_level);
|
||||
level = (tmp_level > level ? tmp_level : level);
|
||||
}
|
||||
return level;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index) {
|
||||
|
@ -870,6 +920,8 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
|
|||
|
||||
ParseNeedStackKernelGraph(kernel_graph_to_device_contexts);
|
||||
|
||||
ParseNodeLevel(control_nodes);
|
||||
|
||||
ParseNeedStackControlNode(control_nodes);
|
||||
|
||||
ParseFormalToRealParameter(control_nodes);
|
||||
|
@ -961,7 +1013,7 @@ bool ControlNodeParser::IsRecursionKernelGraph(const KernelGraphPtr &graph) {
|
|||
MS_LOG(EXCEPTION) << "Invalid kernel graph:" << graph->ToString();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(group_info_iter->second);
|
||||
if (!group_info_iter->second->is_call_input_) {
|
||||
if (!group_info_iter->second->need_stack_) {
|
||||
return false;
|
||||
}
|
||||
for (const auto &front_input_node : group_info_iter->second->front_input_nodes_) {
|
||||
|
@ -1303,7 +1355,7 @@ bool ControlNodeParser::IsCallInputKernelGraph(KernelGraph *const graph) {
|
|||
bool ControlNodeParser::IsCallInputKernelGraphGroup(const std::string &group_name) {
|
||||
for (const auto &graph_group : kernel_graph_group_infos_) {
|
||||
if (group_name.find(graph_group->group_name_) != std ::string::npos) {
|
||||
return graph_group->is_call_input_;
|
||||
return graph_group->need_stack_;
|
||||
}
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Invalid kernel graph group name:" << group_name;
|
||||
|
@ -1697,28 +1749,6 @@ AnfNodePtr ControlNodeParser::FetchRootGraphFrontNodeBySubFrontNode(const AnfNod
|
|||
return sub_front_node_to_root_front_node_[sub_front_node];
|
||||
}
|
||||
|
||||
bool IsFirstControlNode(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes,
|
||||
std::set<AnfNodePtr> unrecursion_call_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(checked_nodes);
|
||||
if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) {
|
||||
return true;
|
||||
}
|
||||
checked_nodes->emplace(node);
|
||||
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &inputs = cnode->inputs();
|
||||
for (const auto &input : inputs) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
if ((AnfAlgo::IsCallNode(input) && unrecursion_call_nodes.find(input) == unrecursion_call_nodes.end()) ||
|
||||
(!IsFirstControlNode(input, checked_nodes, unrecursion_call_nodes))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void ControlNodeParser::ParseFirstControlNodeForFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
|
||||
for (const auto &control_node : control_nodes) {
|
||||
std::set<AnfNodePtr> checked_nodes;
|
||||
|
@ -1806,10 +1836,10 @@ void ControlNodeParser::ParseNeedStackControlNode(const std::vector<AnfNodePtr>
|
|||
if (call_input_num != 0 && (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimDepend))) {
|
||||
need_stack_control_nodes_.emplace(control_node);
|
||||
}
|
||||
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) {
|
||||
auto input_with_indexs = FetchInputNodeByCNode(control_node);
|
||||
if (std::any_of(input_with_indexs.begin(), input_with_indexs.end(),
|
||||
[this](const auto &input_with_index) { return IsRecursionCallNode(input_with_index.first); })) {
|
||||
} else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) ||
|
||||
AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
|
||||
AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
|
||||
if (!IsInputInSameLevel(control_node)) {
|
||||
need_stack_control_nodes_.emplace(control_node);
|
||||
MS_LOG(DEBUG) << "Add need stack control node:" << control_node->DebugString();
|
||||
}
|
||||
|
@ -1850,7 +1880,7 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte
|
|||
continue;
|
||||
}
|
||||
if (AnfAlgo::IsCallNode(front_node_with_index.first)) {
|
||||
kernel_graph_group_info->is_call_input_ = true;
|
||||
kernel_graph_group_info->need_stack_ = true;
|
||||
}
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(front_node_with_index.first, prim::kPrimTupleGetItem)) {
|
||||
|
@ -1877,7 +1907,7 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte
|
|||
}
|
||||
|
||||
kernel_graphs_to_group_info_[kernel_graph] = kernel_graph_group_info;
|
||||
if (kernel_graph_group_info->is_call_input_) {
|
||||
if (kernel_graph_group_info->need_stack_) {
|
||||
call_input_kernel_graphs_.emplace(kernel_graph.get());
|
||||
}
|
||||
}
|
||||
|
@ -1890,6 +1920,97 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte
|
|||
}
|
||||
}
|
||||
|
||||
void ControlNodeParser::ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes) {
|
||||
size_t level = 0;
|
||||
// 1. Parse levels of control nodes.
|
||||
for (const auto &control_node : control_nodes) {
|
||||
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
|
||||
node_to_level_[control_node] = level;
|
||||
level = 0;
|
||||
const auto &func_graph = control_node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
const auto ¶meters = func_graph->parameters();
|
||||
for (const auto ¶meter : parameters) {
|
||||
node_to_level_[parameter] = level;
|
||||
}
|
||||
continue;
|
||||
} else if (AnfAlgo::IsCallNode(control_node) && IsRecursionCallNode(control_node)) {
|
||||
++level;
|
||||
node_to_level_[control_node] = level;
|
||||
} else {
|
||||
std::set<AnfNodePtr> checked_nodes;
|
||||
node_to_level_[control_node] = ParseControlNodeLevel(control_node, &checked_nodes, node_to_level_);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Parse the levels of kernel graph outputs.
|
||||
for (const auto &kernel_graph_group_info : kernel_graph_group_infos_) {
|
||||
level = 0;
|
||||
for (const auto &front_input_node : kernel_graph_group_info->front_input_nodes_) {
|
||||
const auto &input_node = front_input_node.first.first;
|
||||
auto iter = node_to_level_.find(input_node);
|
||||
if (iter != node_to_level_.end() && level < iter->second) {
|
||||
level = iter->second;
|
||||
}
|
||||
}
|
||||
for (const auto &front_output_node : kernel_graph_group_info->front_output_nodes_) {
|
||||
const auto &output_node = front_output_node.first.first;
|
||||
node_to_level_[output_node] = level;
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the levels of kernel graph groups.
|
||||
for (const auto &kernel_graph_group_info : kernel_graph_group_infos_) {
|
||||
size_t max_level = 0;
|
||||
for (const auto &front_input_node : kernel_graph_group_info->front_input_nodes_) {
|
||||
const auto &input_node = front_input_node.first.first;
|
||||
auto iter = node_to_level_.find(input_node);
|
||||
if (iter == node_to_level_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get level by input node:" << input_node->DebugString()
|
||||
<< " for kernel graph:" << kernel_graph_group_info->group_name_;
|
||||
}
|
||||
max_level = (max_level > iter->second ? max_level : iter->second);
|
||||
}
|
||||
if (max_level > 0) {
|
||||
kernel_graph_group_info->need_stack_ = true;
|
||||
kernel_graph_group_info->level_ = max_level;
|
||||
for (const auto &kernel_graph : kernel_graph_group_info->graphs_) {
|
||||
call_input_kernel_graphs_.emplace(kernel_graph.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ControlNodeParser::IsInputInSameLevel(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
auto input_with_indexes = FetchInputNodeByCNode(node);
|
||||
size_t level = SIZE_MAX;
|
||||
for (const auto &input_with_index : input_with_indexes) {
|
||||
auto input_node = input_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (input_node->isa<ValueNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto iter = node_to_level_.find(input_node);
|
||||
if (iter == node_to_level_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to find level by input:" << input_node->DebugString()
|
||||
<< " for node:" << node->DebugString();
|
||||
}
|
||||
if (level == SIZE_MAX) {
|
||||
level = iter->second;
|
||||
continue;
|
||||
}
|
||||
if (level != iter->second) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void ControlNodeParser::CreateDeviceTensorForRootGraphParameter(DeviceContext *default_context) {
|
||||
MS_EXCEPTION_IF_NULL(default_context);
|
||||
for (const auto ¶meter : root_graph_parameters_) {
|
||||
|
|
|
@ -82,8 +82,11 @@ using CallNodeToFuncGraph = mindspore::HashMap<AnfNodePtr, std::set<FuncGraphPtr
|
|||
using KernelGraphToDeviceContext = mindspore::HashMap<KernelGraphPtr, DeviceContext *>;
|
||||
// In the control flow, heterogeneous kernel graphs need to be reconnected in the same group, and the kernel graph
|
||||
// group info is used to store the inputs and outputs of the group.
|
||||
// Need stack indicates whether a stack actor needs to be created for the group.
|
||||
// Level indicates the level of the output of the graph in the group.
|
||||
struct KernelGraphGroupInfo {
|
||||
bool is_call_input_;
|
||||
bool need_stack_{0};
|
||||
size_t level_;
|
||||
std::string group_name_;
|
||||
std::set<KernelGraphPtr> graphs_;
|
||||
std::map<KernelWithIndex, const DeviceContext *> front_input_nodes_;
|
||||
|
@ -128,6 +131,7 @@ class ControlNodeParser {
|
|||
// If there is a recursive call node in the input of the kernel graph, the graph is recursive.
|
||||
bool IsRecursionKernelGraph(const KernelGraphPtr &graph);
|
||||
bool IsSameKernelGraphGroup(const AnfNodePtr &node, const KernelGraphPtr &graph);
|
||||
bool IsInputInSameLevel(const AnfNodePtr &node);
|
||||
|
||||
const std::vector<AnfNodePtr> &control_node_parameters() const { return control_node_parameters_; }
|
||||
const FrontToBackendNodeWithContext &front_to_backend_parameters() const { return front_to_backend_parameters_; }
|
||||
|
@ -217,6 +221,8 @@ class ControlNodeParser {
|
|||
// When a control node or kernel graph has input that is a call node, you need to add a stack actor for it.
|
||||
void ParseNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes);
|
||||
void ParseNeedStackKernelGraph(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts);
|
||||
// Parse the level of inputs and outputs of graphs and all control nodes.
|
||||
void ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes);
|
||||
// When the parameter is directly used as the condition of the switch, there will be no back-end node, and a device
|
||||
// tensor needs to be created for it.
|
||||
void CreateDeviceTensorForRootGraphParameter(DeviceContext *default_context);
|
||||
|
@ -241,6 +247,15 @@ class ControlNodeParser {
|
|||
// id needs to be sent to the gather actor corresponding to the funcgraph, and the gather will send the branch id
|
||||
// to its output switch actor.
|
||||
mindspore::HashMap<AnfNodePtr, int> call_node_to_branch_id_;
|
||||
// Level indicates that the input of the node depends on the number of the recursive call node in the funcgraph.
|
||||
// During graph scheduler, the input needs to be graded according to the input's dependence on the recursive call
|
||||
// node, and according to this level, the lower-level inputs are pushed in the stack actor. When arranging, first
|
||||
// sort the call nodes in the funcgraph according to their topological relationships, and then confirm the
|
||||
// dependencies of other nodes on these call nodes in turn.
|
||||
// For example, the dependencies are a -> b, b -> d, c -> d, where b is a call node, then the level of a and c is 0,
|
||||
// and the level of bd is 1, then since d has inputs with different levels of b and c, it is necessary to add a
|
||||
// stack to d.
|
||||
mindspore::HashMap<AnfNodePtr, size_t> node_to_level_;
|
||||
CallNodeToFuncGraph call_node_to_func_graphs_;
|
||||
// host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph.
|
||||
// When initializing the weights, all related weights need to be recorded as the same device tensor.
|
||||
|
|
|
@ -81,16 +81,6 @@ void FetchRealDependNodeByAutoMonad(const AnfNodePtr &node, std::set<AnfNodePtr>
|
|||
}
|
||||
}
|
||||
|
||||
bool IsControlFlowArrow(const ControlNodeParserPtr &parser, const KernelGraphPtr &graph, const AnfNodePtr &from_node) {
|
||||
MS_EXCEPTION_IF_NULL(parser);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(from_node);
|
||||
bool is_call_input_kernl_graph = parser->IsCallInputKernelGraph(graph.get());
|
||||
return ((!is_call_input_kernl_graph) && ((from_node == nullptr) || (!from_node->isa<Parameter>()))) ||
|
||||
(from_node != nullptr && IsPersistentDeviceTensor(from_node)) ||
|
||||
(from_node != nullptr && parser->IsSameKernelGraphGroup(from_node, graph));
|
||||
}
|
||||
|
||||
// Parameter and ref node can not copy the device tensor.
|
||||
bool is_need_copy_device_tensor(const AnfNodePtr &backend_node, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(backend_node);
|
||||
|
@ -317,10 +307,10 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
|
|||
|
||||
// Create a corresponding stack actor for each kernel graph that has a call node as input.
|
||||
for (const auto &kernel_graph_group_info : parser->kernel_graph_group_infos_) {
|
||||
if (!kernel_graph_group_info->is_call_input_) {
|
||||
if (!kernel_graph_group_info->need_stack_) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &actor_name = kernel_graph_group_info->group_name_ + kStackActorNameSuffix;
|
||||
size_t input_parameter_data_num = 0;
|
||||
std::vector<const DeviceContext *> device_contexts;
|
||||
std::vector<KernelWithIndex> formal_parameters;
|
||||
|
@ -329,8 +319,13 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
|
|||
// If the input comes from inside funcgraph, put it at the front of the vector, otherwise put it at the end.
|
||||
const auto &from_node = node_with_context.first.first;
|
||||
MS_EXCEPTION_IF_NULL(from_node);
|
||||
const auto &graph = (from_node->isa<CNode>() ? parser->FetchKernelGraphByFrontNode(from_node) : nullptr);
|
||||
if (parser->IsRecursionCallNode(from_node) || (graph != nullptr && parser->IsRecursionKernelGraph(graph))) {
|
||||
auto iter = parser->node_to_level_.find(from_node);
|
||||
if (iter == parser->node_to_level_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get level by from node:" << from_node->DebugString()
|
||||
<< " in graph:" << kernel_graph_group_info->group_name_;
|
||||
}
|
||||
if (iter->second == kernel_graph_group_info->level_ &&
|
||||
(!(parser->IsRootGraphParameter(from_node) && IsPersistentDeviceTensor(from_node)))) {
|
||||
formal_parameters.emplace_back(node_with_context.first);
|
||||
device_contexts.emplace_back(node_with_context.second);
|
||||
} else {
|
||||
|
@ -339,7 +334,6 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
|
|||
input_parameter_data_num++;
|
||||
}
|
||||
}
|
||||
const auto &actor_name = kernel_graph_group_info->group_name_ + kStackActorNameSuffix;
|
||||
const auto &stack_actor = std::make_shared<StackActor>(actor_name, memory_manager_aid_, formal_parameters);
|
||||
stack_actors.emplace_back(stack_actor);
|
||||
stack_actor->device_contexts_.swap(device_contexts);
|
||||
|
@ -360,6 +354,7 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo
|
|||
for (const auto &need_stack_control_node : parser->need_stack_control_nodes_) {
|
||||
MS_EXCEPTION_IF_NULL(need_stack_control_node);
|
||||
|
||||
const auto &stack_actor_name = GetActorName(need_stack_control_node) + kStackActorNameSuffix;
|
||||
std::vector<KernelWithIndex> formal_parameters;
|
||||
std::vector<const DeviceContext *> device_contexts;
|
||||
size_t input_parameter_data_num = 0;
|
||||
|
@ -372,12 +367,20 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
control_actor_name = func_graph->ToString() + kExitActorNameSuffix;
|
||||
} else if (AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimPartial) ||
|
||||
AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimSwitch) ||
|
||||
AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimSwitchLayer) ||
|
||||
AnfAlgo::IsCallNode(need_stack_control_node)) {
|
||||
control_actor_name = GetActorName(need_stack_control_node);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid control node:" << need_stack_control_node->DebugString();
|
||||
}
|
||||
|
||||
auto iter = parser->node_to_level_.find(need_stack_control_node);
|
||||
if (iter == parser->node_to_level_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get level for need stack control node:" << need_stack_control_node->DebugString();
|
||||
}
|
||||
size_t control_node_level = iter->second;
|
||||
|
||||
auto actor = FetchActor(control_actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto control_actor = dynamic_cast<ControlActor *>(actor);
|
||||
|
@ -392,13 +395,20 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo
|
|||
for (size_t i = 0; i < control_actor->formal_parameters_.size(); ++i) {
|
||||
const auto ¶meter = control_actor->formal_parameters_[i];
|
||||
auto device_context = control_actor->device_contexts_[i];
|
||||
const auto &graph =
|
||||
(parameter.first->isa<CNode>() ? parser->FetchKernelGraphByFrontNode(parameter.first) : nullptr);
|
||||
if (parser->IsRecursionCallNode(parameter.first) || (graph != nullptr && parser->IsRecursionKernelGraph(graph))) {
|
||||
if (parameter.first->isa<ValueNode>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
iter = parser->node_to_level_.find(parameter.first);
|
||||
if (iter == parser->node_to_level_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get level for formal parameter:" << parameter.first->DebugString()
|
||||
<< " for need stack control node:" << need_stack_control_node->DebugString();
|
||||
}
|
||||
|
||||
if (control_node_level == iter->second &&
|
||||
(!(parser->IsRootGraphParameter(parameter.first) && IsPersistentDeviceTensor(parameter.first)))) {
|
||||
formal_parameters.emplace_back(parameter);
|
||||
device_contexts.emplace_back(device_context);
|
||||
} else if (parameter.first->isa<ValueNode>()) {
|
||||
continue;
|
||||
} else {
|
||||
formal_parameters.insert(formal_parameters.begin(), parameter);
|
||||
device_contexts.insert(device_contexts.begin(), device_context);
|
||||
|
@ -415,7 +425,6 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo
|
|||
}
|
||||
}
|
||||
// Create stack actor.
|
||||
const auto &stack_actor_name = GetActorName(need_stack_control_node) + kStackActorNameSuffix;
|
||||
const auto &stack_actor = std::make_shared<StackActor>(stack_actor_name, memory_manager_aid_, formal_parameters);
|
||||
stack_actor->device_contexts_ = device_contexts;
|
||||
stack_actor->input_stack_data_num_ = input_parameter_data_num;
|
||||
|
@ -495,9 +504,20 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
|
|||
MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
|
||||
const auto &parser = graph_compiler_info.control_node_parser_;
|
||||
for (auto &switch_actor : control_actor_set->switch_actors_) {
|
||||
for (size_t i = 0; i < switch_actor->formal_parameters_.size(); ++i) {
|
||||
LinkArrowbyFormalParameter(switch_actor.get(), switch_actor->formal_parameters_[i], {switch_actor->node_, i},
|
||||
parser);
|
||||
MS_EXCEPTION_IF_NULL(switch_actor);
|
||||
if (parser->need_stack_control_nodes_.find(switch_actor->node_) == parser->need_stack_control_nodes_.end()) {
|
||||
for (size_t i = 0; i < switch_actor->formal_parameters_.size(); ++i) {
|
||||
LinkArrowbyFormalParameter(switch_actor.get(), switch_actor->formal_parameters_[i], {switch_actor->node_, i},
|
||||
parser);
|
||||
}
|
||||
} else {
|
||||
// If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
|
||||
auto stack_actor_name = GetActorName(switch_actor->node_) + kStackActorNameSuffix;
|
||||
auto actor = FetchActor(stack_actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto stack_actor = dynamic_cast<StackActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(stack_actor);
|
||||
LinkArrowFromStackActor(stack_actor, switch_actor.get());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -601,7 +621,14 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
|
|||
MS_EXCEPTION_IF_NULL(actor);
|
||||
const auto &switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(switch_actor);
|
||||
LinkPartialArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second);
|
||||
|
||||
const auto &abstract = from_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
if (abstract->isa<abstract::AbstractFunction>()) {
|
||||
LinkPartialArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second);
|
||||
} else {
|
||||
LinkDataArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second);
|
||||
}
|
||||
} else if (AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) {
|
||||
// Link arrow from gather actor
|
||||
const auto &actor_name = GetActorName(from_node);
|
||||
|
@ -934,7 +961,7 @@ void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_
|
|||
continue;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto actor_name = kernel_graph->ToString() + kStackActorNameSuffix;
|
||||
auto actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph) + kStackActorNameSuffix;
|
||||
if (!parser->IsCallInputKernelGraph(kernel_graph.get())) {
|
||||
const auto &func_graph = parser->FetchFuncGraphByKernelGraph(kernel_graph.get());
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -1135,10 +1162,8 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap
|
|||
auto input_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
|
||||
auto input = input_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
if (sink_input_node_linked.count(input) > 0) {
|
||||
continue;
|
||||
}
|
||||
if ((!input->isa<Parameter>()) || HasAbstractMonad(input) || IsPersistentDeviceTensor(input)) {
|
||||
if (sink_input_node_linked.count(input) > 0 || HasAbstractMonad(input) || parser == nullptr ||
|
||||
(!parser->IsControlFlowDataArrow(graph, input))) {
|
||||
continue;
|
||||
}
|
||||
auto front_node = graph->GetFrontAnfByBackendAnf(input);
|
||||
|
@ -1159,16 +1184,23 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap
|
|||
// If the formal parameter is a tuple type, the parameter of the kernel graph will not directly correspond
|
||||
// to the front parameter, but the node in the internal parameter.
|
||||
const auto &from_node = from_node_with_index.first;
|
||||
if (IsControlFlowArrow(parser, graph, from_node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Fetch actor and link.
|
||||
auto type = FetchKernelTransformType(kernel, graph, {});
|
||||
auto to_actor = FetchActor(type, "", kernel, graph);
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
size_t from_index = 0;
|
||||
if (AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitch) ||
|
||||
AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitchLayer)) {
|
||||
const auto &actor_name = GetActorName(from_node);
|
||||
auto actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
from_actor = dynamic_cast<ControlActor *>(actor);
|
||||
} else {
|
||||
from_index = from_actor->FetchNodePosition(from_node_with_index);
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
auto from_index = from_actor->FetchNodePosition(from_node_with_index);
|
||||
auto to_index = i;
|
||||
if (type == KernelTransformType::kSuperKernelActor) {
|
||||
auto super_kernel_actor = dynamic_cast<SuperKernelActor *>(to_actor);
|
||||
|
|
|
@ -559,7 +559,8 @@ void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_mu
|
|||
MS_EXCEPTION_IF_NULL(cut_node);
|
||||
MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->DebugString();
|
||||
control_nodes_.push_back(cut_node);
|
||||
if (AnfAlgo::IsCallNode(cut_node)) {
|
||||
if (AnfAlgo::IsCallNode(cut_node) || AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitch) ||
|
||||
AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitchLayer)) {
|
||||
const auto &func_graph = cut_node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
func_graph_to_kernel_graph_ids_[func_graph].emplace_back(std::vector<GraphId>());
|
||||
|
|
Loading…
Reference in New Issue