Parse kernel level by kernel graph.
This commit is contained in:
parent
b85161cb94
commit
911f664b0e
|
@ -358,6 +358,10 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
|
||||||
if (abstract != nullptr) {
|
if (abstract != nullptr) {
|
||||||
(*abstract) = sub_abstracts[GetTupleGetItemOutIndex(cnode)];
|
(*abstract) = sub_abstracts[GetTupleGetItemOutIndex(cnode)];
|
||||||
MS_EXCEPTION_IF_NULL((*abstract));
|
MS_EXCEPTION_IF_NULL((*abstract));
|
||||||
|
} else {
|
||||||
|
// In recursion of getitem node, the index of the first input of its real node is returned.
|
||||||
|
// When the recursion ends, the outermost index needs to be accumulated.
|
||||||
|
real_index += index;
|
||||||
}
|
}
|
||||||
return {item_with_index_tmp.first, real_index};
|
return {item_with_index_tmp.first, real_index};
|
||||||
}
|
}
|
||||||
|
|
|
@ -407,7 +407,7 @@ void FetchAllExecutionFunction(const FuncGraphPtr &func_graph, std::set<FuncGrap
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isValidMonadNode(const AnfNodePtr &node) {
|
bool IsValidMonadNode(const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
return node->isa<ValueNode>() || node->isa<Parameter>() || AnfAlgo::IsCallNode(node);
|
return node->isa<ValueNode>() || node->isa<Parameter>() || AnfAlgo::IsCallNode(node);
|
||||||
}
|
}
|
||||||
|
@ -419,7 +419,7 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
||||||
const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
|
const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
|
||||||
const auto &real_node = real_node_with_index.first;
|
const auto &real_node = real_node_with_index.first;
|
||||||
MS_EXCEPTION_IF_NULL(real_node);
|
MS_EXCEPTION_IF_NULL(real_node);
|
||||||
if (isValidMonadNode(real_node)) {
|
if (IsValidMonadNode(real_node)) {
|
||||||
return {real_node_with_index};
|
return {real_node_with_index};
|
||||||
}
|
}
|
||||||
MS_LOG(EXCEPTION) << "Invalid monad node:" << real_node->DebugString();
|
MS_LOG(EXCEPTION) << "Invalid monad node:" << real_node->DebugString();
|
||||||
|
@ -561,33 +561,6 @@ bool IsFirstControlNode(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_no
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the level of the control node, recursively traverse all the inputs of the node, and find the largest level
|
|
||||||
// among them.
|
|
||||||
size_t ParseControlNodeLevel(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes,
|
|
||||||
const mindspore::HashMap<AnfNodePtr, size_t> &node_to_level) {
|
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
|
||||||
MS_EXCEPTION_IF_NULL(checked_nodes);
|
|
||||||
if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
(void)checked_nodes->emplace(node);
|
|
||||||
|
|
||||||
auto iter = node_to_level.find(node);
|
|
||||||
if (iter != node_to_level.end()) {
|
|
||||||
return iter->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t level = 0;
|
|
||||||
const auto &cnode = node->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
|
||||||
const auto &inputs = cnode->inputs();
|
|
||||||
for (const auto &input : inputs) {
|
|
||||||
size_t tmp_level = ParseControlNodeLevel(input, checked_nodes, node_to_level);
|
|
||||||
level = (tmp_level > level ? tmp_level : level);
|
|
||||||
}
|
|
||||||
return level;
|
|
||||||
}
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index) {
|
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index) {
|
||||||
|
@ -755,7 +728,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
|
||||||
|
|
||||||
ParseUnRecursionCallNode();
|
ParseUnRecursionCallNode();
|
||||||
|
|
||||||
ParseNeedStackKernelGraph(kernel_graph_to_device_contexts);
|
ParseKernelGraphGroup(kernel_graph_to_device_contexts);
|
||||||
|
|
||||||
ParseNodeLevel(control_nodes);
|
ParseNodeLevel(control_nodes);
|
||||||
|
|
||||||
|
@ -841,14 +814,21 @@ bool ControlNodeParser::IsRootGraphPersistentDeviceTensor(const AnfNodePtr &node
|
||||||
return find(root_graph_parameters_.begin(), root_graph_parameters_.end(), node) != root_graph_parameters_.end();
|
return find(root_graph_parameters_.begin(), root_graph_parameters_.end(), node) != root_graph_parameters_.end();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ControlNodeParser::IsNeedStackControlNode(const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (!(node->isa<CNode>())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return need_stack_control_nodes_.find(node) != need_stack_control_nodes_.end();
|
||||||
|
}
|
||||||
|
|
||||||
bool ControlNodeParser::IsRecursionCallNode(const AnfNodePtr &node) {
|
bool ControlNodeParser::IsRecursionCallNode(const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (!AnfAlgo::IsCallNode(node)) {
|
if (!AnfAlgo::IsCallNode(node)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return (find(unrecursion_call_nodes_.begin(), unrecursion_call_nodes_.end(), node) ==
|
return find(unrecursion_call_nodes_.begin(), unrecursion_call_nodes_.end(), node) == unrecursion_call_nodes_.end();
|
||||||
unrecursion_call_nodes_.end()) ||
|
|
||||||
(need_stack_control_nodes_.find(node) != need_stack_control_nodes_.end());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ControlNodeParser::IsRecursionKernelGraph(const KernelGraphPtr &graph) {
|
bool ControlNodeParser::IsRecursionKernelGraph(const KernelGraphPtr &graph) {
|
||||||
|
@ -1852,6 +1832,16 @@ void CollectEffectiveOutputByGraph(const KernelGraphPtr &graph, DeviceContext *c
|
||||||
backend_to_front.second.first->isa<ValueNode>()) {
|
backend_to_front.second.first->isa<ValueNode>()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip the function input.
|
||||||
|
const auto &abstract = backend_to_front.second.first->abstract();
|
||||||
|
MS_EXCEPTION_IF_NULL(abstract);
|
||||||
|
const auto &real_abstract = FetchAbstractByIndex(abstract, backend_to_front.second.second);
|
||||||
|
MS_EXCEPTION_IF_NULL(real_abstract);
|
||||||
|
if (real_abstract->isa<abstract::AbstractFunction>()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "Kernel graph:" << graph->ToString()
|
MS_LOG(DEBUG) << "Kernel graph:" << graph->ToString()
|
||||||
<< " add front output node:" << backend_to_front.second.first->DebugString()
|
<< " add front output node:" << backend_to_front.second.first->DebugString()
|
||||||
<< " index:" << backend_to_front.second.second
|
<< " index:" << backend_to_front.second.second
|
||||||
|
@ -1861,7 +1851,7 @@ void CollectEffectiveOutputByGraph(const KernelGraphPtr &graph, DeviceContext *c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts) {
|
void ControlNodeParser::ParseKernelGraphGroup(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts) {
|
||||||
for (const auto &func_graph_to_kernel_graph_groups : func_graph_to_kernel_graph_groups_) {
|
for (const auto &func_graph_to_kernel_graph_groups : func_graph_to_kernel_graph_groups_) {
|
||||||
for (const auto &kernel_graph_group : func_graph_to_kernel_graph_groups.second) {
|
for (const auto &kernel_graph_group : func_graph_to_kernel_graph_groups.second) {
|
||||||
if (kernel_graph_group.empty()) {
|
if (kernel_graph_group.empty()) {
|
||||||
|
@ -1903,6 +1893,49 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t ControlNodeParser::ParseControlNodeLevel(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
MS_EXCEPTION_IF_NULL(checked_nodes);
|
||||||
|
if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
(void)checked_nodes->emplace(node);
|
||||||
|
|
||||||
|
auto iter = node_to_level_.find(node);
|
||||||
|
if (iter != node_to_level_.end()) {
|
||||||
|
return iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t level = 0;
|
||||||
|
const auto &kernel_graph = FetchKernelGraphByFrontNode(node);
|
||||||
|
if (kernel_graph == nullptr) {
|
||||||
|
// If the kernel graph is not found, it means that the input does not come from the kernel graph, then
|
||||||
|
// just continue to traverse the input.
|
||||||
|
const auto &cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
const auto &inputs = cnode->inputs();
|
||||||
|
for (const auto &input : inputs) {
|
||||||
|
size_t tmp_level = ParseControlNodeLevel(input, checked_nodes);
|
||||||
|
level = (tmp_level > level ? tmp_level : level);
|
||||||
|
}
|
||||||
|
return level;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the input comes from the kernel graph, you need to check all the graph's input, not just the node's input.
|
||||||
|
auto group_info_iter = kernel_graphs_to_group_info_.find(kernel_graph);
|
||||||
|
if (group_info_iter == kernel_graphs_to_group_info_.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Failed to get kernel graph group info for graph:" << kernel_graph->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &inputs = group_info_iter->second->front_input_nodes_;
|
||||||
|
for (const auto &input : inputs) {
|
||||||
|
const auto &node = input.first.first;
|
||||||
|
size_t tmp_level = ParseControlNodeLevel(node, checked_nodes);
|
||||||
|
level = (tmp_level > level ? tmp_level : level);
|
||||||
|
}
|
||||||
|
return level;
|
||||||
|
}
|
||||||
|
|
||||||
void ControlNodeParser::ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes) {
|
void ControlNodeParser::ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes) {
|
||||||
size_t level = 0;
|
size_t level = 0;
|
||||||
// 1. Parse levels of control nodes.
|
// 1. Parse levels of control nodes.
|
||||||
|
@ -1917,12 +1950,12 @@ void ControlNodeParser::ParseNodeLevel(const std::vector<AnfNodePtr> &control_no
|
||||||
node_to_level_[parameter] = level;
|
node_to_level_[parameter] = level;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
} else if (AnfAlgo::IsCallNode(control_node) && IsRecursionCallNode(control_node)) {
|
} else if (IsRecursionCallNode(control_node)) {
|
||||||
++level;
|
++level;
|
||||||
node_to_level_[control_node] = level;
|
node_to_level_[control_node] = level;
|
||||||
} else {
|
} else {
|
||||||
std::set<AnfNodePtr> checked_nodes;
|
std::set<AnfNodePtr> checked_nodes;
|
||||||
node_to_level_[control_node] = ParseControlNodeLevel(control_node, &checked_nodes, node_to_level_);
|
node_to_level_[control_node] = ParseControlNodeLevel(control_node, &checked_nodes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -134,6 +134,7 @@ class ControlNodeParser {
|
||||||
// persistent and real parameters passed.
|
// persistent and real parameters passed.
|
||||||
bool IsRootGraphPersistentDeviceTensor(const AnfNodePtr &node);
|
bool IsRootGraphPersistentDeviceTensor(const AnfNodePtr &node);
|
||||||
bool IsRecursionCallNode(const AnfNodePtr &node);
|
bool IsRecursionCallNode(const AnfNodePtr &node);
|
||||||
|
bool IsNeedStackControlNode(const AnfNodePtr &node);
|
||||||
// If there is a recursive call node in the input of the kernel graph, the graph is recursive.
|
// If there is a recursive call node in the input of the kernel graph, the graph is recursive.
|
||||||
bool IsRecursionKernelGraph(const KernelGraphPtr &graph);
|
bool IsRecursionKernelGraph(const KernelGraphPtr &graph);
|
||||||
bool IsSameKernelGraphGroup(const AnfNodePtr &node, const KernelGraphPtr &graph);
|
bool IsSameKernelGraphGroup(const AnfNodePtr &node, const KernelGraphPtr &graph);
|
||||||
|
@ -227,9 +228,12 @@ class ControlNodeParser {
|
||||||
// When a control node or kernel graph has input that is a call node, you need to add a stack actor for it.
|
// When a control node or kernel graph has input that is a call node, you need to add a stack actor for it.
|
||||||
void ParseNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes);
|
void ParseNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes);
|
||||||
bool IsCallNodeNeedStack(const AnfNodePtr &node);
|
bool IsCallNodeNeedStack(const AnfNodePtr &node);
|
||||||
void ParseNeedStackKernelGraph(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts);
|
void ParseKernelGraphGroup(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts);
|
||||||
// Parse the level of inputs and outputs of graphs and all control nodes.
|
// Parse the level of inputs and outputs of graphs and all control nodes.
|
||||||
void ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes);
|
void ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes);
|
||||||
|
// Get the level of the control node, recursively traverse all the inputs of the node, and find the largest level
|
||||||
|
// among them.
|
||||||
|
size_t ParseControlNodeLevel(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes);
|
||||||
// When the parameter is directly used as the condition of the switch, there will be no back-end node, and a device
|
// When the parameter is directly used as the condition of the switch, there will be no back-end node, and a device
|
||||||
// tensor needs to be created for it.
|
// tensor needs to be created for it.
|
||||||
void CreateDeviceTensorForRootGraphParameter(DeviceContext *const default_context);
|
void CreateDeviceTensorForRootGraphParameter(DeviceContext *const default_context);
|
||||||
|
|
|
@ -525,7 +525,7 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
|
||||||
const auto &parser = graph_compiler_info.control_node_parser_;
|
const auto &parser = graph_compiler_info.control_node_parser_;
|
||||||
for (auto &switch_actor : control_actor_set->switch_actors_) {
|
for (auto &switch_actor : control_actor_set->switch_actors_) {
|
||||||
MS_EXCEPTION_IF_NULL(switch_actor);
|
MS_EXCEPTION_IF_NULL(switch_actor);
|
||||||
if (parser->need_stack_control_nodes_.find(switch_actor->node_) == parser->need_stack_control_nodes_.end()) {
|
if (!parser->IsNeedStackControlNode(switch_actor->node_)) {
|
||||||
for (size_t i = 0; i < switch_actor->formal_parameters_.size(); ++i) {
|
for (size_t i = 0; i < switch_actor->formal_parameters_.size(); ++i) {
|
||||||
LinkArrowbyFormalParameter(switch_actor.get(), switch_actor->formal_parameters_[i], {switch_actor->node_, i},
|
LinkArrowbyFormalParameter(switch_actor.get(), switch_actor->formal_parameters_[i], {switch_actor->node_, i},
|
||||||
parser);
|
parser);
|
||||||
|
@ -537,13 +537,13 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
|
||||||
MS_EXCEPTION_IF_NULL(actor);
|
MS_EXCEPTION_IF_NULL(actor);
|
||||||
auto stack_actor = dynamic_cast<StackActor *>(actor);
|
auto stack_actor = dynamic_cast<StackActor *>(actor);
|
||||||
MS_EXCEPTION_IF_NULL(stack_actor);
|
MS_EXCEPTION_IF_NULL(stack_actor);
|
||||||
LinkArrowFromStackActor(stack_actor, switch_actor.get());
|
LinkArrowFromStackActor(stack_actor, switch_actor.get(), parser);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &gather_actor : control_actor_set->gather_actors_) {
|
for (auto &gather_actor : control_actor_set->gather_actors_) {
|
||||||
MS_EXCEPTION_IF_NULL(gather_actor->node_);
|
MS_EXCEPTION_IF_NULL(gather_actor->node_);
|
||||||
if (parser->need_stack_control_nodes_.find(gather_actor->node_) == parser->need_stack_control_nodes_.end()) {
|
if (!parser->IsNeedStackControlNode(gather_actor->node_)) {
|
||||||
for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) {
|
for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) {
|
||||||
LinkArrowbyFormalParameter(gather_actor.get(), gather_actor->formal_parameters_[i], {gather_actor->node_, i},
|
LinkArrowbyFormalParameter(gather_actor.get(), gather_actor->formal_parameters_[i], {gather_actor->node_, i},
|
||||||
parser);
|
parser);
|
||||||
|
@ -555,7 +555,7 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
|
||||||
MS_EXCEPTION_IF_NULL(actor);
|
MS_EXCEPTION_IF_NULL(actor);
|
||||||
auto stack_actor = dynamic_cast<StackActor *>(actor);
|
auto stack_actor = dynamic_cast<StackActor *>(actor);
|
||||||
MS_EXCEPTION_IF_NULL(stack_actor);
|
MS_EXCEPTION_IF_NULL(stack_actor);
|
||||||
LinkArrowFromStackActor(stack_actor, gather_actor.get());
|
LinkArrowFromStackActor(stack_actor, gather_actor.get(), parser);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -567,20 +567,19 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
|
||||||
|
|
||||||
for (auto &exit_actor : control_actor_set->exit_actors_) {
|
for (auto &exit_actor : control_actor_set->exit_actors_) {
|
||||||
MS_EXCEPTION_IF_NULL(exit_actor);
|
MS_EXCEPTION_IF_NULL(exit_actor);
|
||||||
if (exit_actor->node_ == nullptr ||
|
|
||||||
(parser->need_stack_control_nodes_.find(exit_actor->node_) == parser->need_stack_control_nodes_.end())) {
|
auto stack_actor_name = (exit_actor->node_ == nullptr ? GetStackActorNameByExitName(exit_actor->GetAID().Name())
|
||||||
|
: GetActorName(exit_actor->node_) + kStackActorNameSuffix);
|
||||||
|
auto actor = FetchActor(stack_actor_name);
|
||||||
|
if (actor == nullptr) {
|
||||||
for (size_t i = 0; i < exit_actor->formal_parameters_.size(); ++i) {
|
for (size_t i = 0; i < exit_actor->formal_parameters_.size(); ++i) {
|
||||||
LinkArrowbyFormalParameter(exit_actor.get(), exit_actor->formal_parameters_[i], {exit_actor->node_, i}, parser);
|
LinkArrowbyFormalParameter(exit_actor.get(), exit_actor->formal_parameters_[i], {exit_actor->node_, i}, parser);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
|
// If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
|
||||||
auto stack_actor_name = (exit_actor->node_ == nullptr ? GetStackActorNameByExitName(exit_actor->GetAID().Name())
|
|
||||||
: GetActorName(exit_actor->node_) + kStackActorNameSuffix);
|
|
||||||
auto actor = FetchActor(stack_actor_name);
|
|
||||||
MS_EXCEPTION_IF_NULL(actor);
|
|
||||||
auto stack_actor = dynamic_cast<StackActor *>(actor);
|
auto stack_actor = dynamic_cast<StackActor *>(actor);
|
||||||
MS_EXCEPTION_IF_NULL(stack_actor);
|
MS_EXCEPTION_IF_NULL(stack_actor);
|
||||||
LinkArrowFromStackActor(stack_actor, exit_actor.get());
|
LinkArrowFromStackActor(stack_actor, exit_actor.get(), parser);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -592,7 +591,8 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor, ControlActor *const to_actor) {
|
void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor, ControlActor *const to_actor,
|
||||||
|
const ControlNodeParserPtr &parser) {
|
||||||
MS_EXCEPTION_IF_NULL(stack_actor);
|
MS_EXCEPTION_IF_NULL(stack_actor);
|
||||||
MS_EXCEPTION_IF_NULL(to_actor);
|
MS_EXCEPTION_IF_NULL(to_actor);
|
||||||
|
|
||||||
|
@ -605,6 +605,14 @@ void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch the arrow type of input.
|
// Fetch the arrow type of input.
|
||||||
|
if (to_actor->type_ == KernelTransformType::kExitActor && to_actor->node_ == nullptr && from_node->isa<CNode>() &&
|
||||||
|
(!AnfAlgo::IsCallNode(from_node)) && (!AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) &&
|
||||||
|
to_actor->GetAID().Name().find(
|
||||||
|
parser->FetchGroupNameByKernelGraph(parser->FetchKernelGraphByFrontNode(from_node))) != std::string::npos) {
|
||||||
|
LinkArrowByKernel(from_node, to_actor, formal_parameter, {to_actor->node_, to_index}, parser);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
size_t from_index = stack_actor->FetchNodePosition(formal_parameter);
|
size_t from_index = stack_actor->FetchNodePosition(formal_parameter);
|
||||||
const auto &abstract = formal_parameter.first->abstract();
|
const auto &abstract = formal_parameter.first->abstract();
|
||||||
MS_EXCEPTION_IF_NULL(abstract);
|
MS_EXCEPTION_IF_NULL(abstract);
|
||||||
|
@ -876,7 +884,7 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
|
||||||
}
|
}
|
||||||
|
|
||||||
auto from_actor = control_actor;
|
auto from_actor = control_actor;
|
||||||
if (parser->need_stack_control_nodes_.find(node) != parser->need_stack_control_nodes_.end()) {
|
if (parser->IsNeedStackControlNode(node)) {
|
||||||
const auto &stack_actor_name = GetActorName(node) + kStackActorNameSuffix;
|
const auto &stack_actor_name = GetActorName(node) + kStackActorNameSuffix;
|
||||||
auto actor = FetchActor(stack_actor_name);
|
auto actor = FetchActor(stack_actor_name);
|
||||||
MS_EXCEPTION_IF_NULL(actor);
|
MS_EXCEPTION_IF_NULL(actor);
|
||||||
|
@ -1116,8 +1124,8 @@ void ControlNodeScheduler::LinkControlArrowByAutoMonad(ControlActor *to_actor, c
|
||||||
(void)from_actors.emplace_back(from_actor);
|
(void)from_actors.emplace_back(from_actor);
|
||||||
LinkControlArrow(from_actor, to_actor);
|
LinkControlArrow(from_actor, to_actor);
|
||||||
}
|
}
|
||||||
if (to_actor->type_ != KernelTransformType::kStackActor || parser->IsRecursionCallNode(depend_node) ||
|
if (to_actor->type_ != KernelTransformType::kStackActor || parser->IsNeedStackControlNode(depend_node) ||
|
||||||
(graph != nullptr && parser->IsRecursionKernelGraph(graph))) {
|
parser->IsRecursionCallNode(depend_node) || (graph != nullptr && parser->IsRecursionKernelGraph(graph))) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// If the control arrow comes from a recursive call node or a recursive kernel graph, these control edges will be
|
// If the control arrow comes from a recursive call node or a recursive kernel graph, these control edges will be
|
||||||
|
|
|
@ -76,7 +76,7 @@ class ControlNodeScheduler {
|
||||||
void LinkArrowByValueNode(const AnfNodePtr &value_node, ControlActor *const to_actor, size_t from_index,
|
void LinkArrowByValueNode(const AnfNodePtr &value_node, ControlActor *const to_actor, size_t from_index,
|
||||||
size_t to_index);
|
size_t to_index);
|
||||||
// Link arrow from stack actor to control actor.
|
// Link arrow from stack actor to control actor.
|
||||||
void LinkArrowFromStackActor(StackActor *stack_actor, ControlActor *to_actor);
|
void LinkArrowFromStackActor(StackActor *stack_actor, ControlActor *to_actor, const ControlNodeParserPtr &parser);
|
||||||
|
|
||||||
// Link data arrow between control actor and actor in frame, including kernel actor, output actor, data source actor.
|
// Link data arrow between control actor and actor in frame, including kernel actor, output actor, data source actor.
|
||||||
void LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info);
|
void LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info);
|
||||||
|
|
|
@ -60,7 +60,6 @@ class BackwardNet(nn.Cell):
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
@pytest.mark.platform_x86_gpu_training
|
@pytest.mark.platform_x86_gpu_training
|
||||||
@pytest.mark.platform_x86_cpu
|
|
||||||
@pytest.mark.platform_arm_ascend_training
|
@pytest.mark.platform_arm_ascend_training
|
||||||
@pytest.mark.platform_x86_ascend_training
|
@pytest.mark.platform_x86_ascend_training
|
||||||
@pytest.mark.env_onecard
|
@pytest.mark.env_onecard
|
||||||
|
|
Loading…
Reference in New Issue