Parse kernel level by kernel graph.

This commit is contained in:
gaoyong10 2022-02-07 15:58:49 +08:00
parent b85161cb94
commit 911f664b0e
6 changed files with 102 additions and 54 deletions

View File

@ -358,6 +358,10 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
if (abstract != nullptr) {
(*abstract) = sub_abstracts[GetTupleGetItemOutIndex(cnode)];
MS_EXCEPTION_IF_NULL((*abstract));
} else {
// In recursion of getitem node, the index of the first input of its real node is returned.
// When the recursion ends, the outermost index needs to be accumulated.
real_index += index;
}
return {item_with_index_tmp.first, real_index};
}

View File

@ -407,7 +407,7 @@ void FetchAllExecutionFunction(const FuncGraphPtr &func_graph, std::set<FuncGrap
}
}
bool isValidMonadNode(const AnfNodePtr &node) {
// Returns true when the node is a ValueNode, a Parameter, or a call node (as judged by AnfAlgo::IsCallNode).
// A null node is rejected by MS_EXCEPTION_IF_NULL before any check is made.
bool IsValidMonadNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<ValueNode>() || node->isa<Parameter>()) {
    return true;
  }
  return AnfAlgo::IsCallNode(node);
}
@ -419,7 +419,7 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
const auto &real_node = real_node_with_index.first;
MS_EXCEPTION_IF_NULL(real_node);
if (isValidMonadNode(real_node)) {
if (IsValidMonadNode(real_node)) {
return {real_node_with_index};
}
MS_LOG(EXCEPTION) << "Invalid monad node:" << real_node->DebugString();
@ -561,33 +561,6 @@ bool IsFirstControlNode(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_no
}
return true;
}
// Compute the level of a control node: walk the inputs of the node recursively and return the largest level
// recorded in node_to_level among the nodes that are reached.
size_t ParseControlNodeLevel(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes,
                             const mindspore::HashMap<AnfNodePtr, size_t> &node_to_level) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(checked_nodes);
  // Anything that is not a CNode contributes level 0 and is not recorded as checked.
  if (!node->isa<CNode>()) {
    return 0;
  }
  // emplace() reports whether the node was newly inserted; a node already seen in this traversal contributes 0.
  if (!checked_nodes->emplace(node).second) {
    return 0;
  }
  // A node whose level is already recorded ends the recursion with that level.
  const auto known_level = node_to_level.find(node);
  if (known_level != node_to_level.end()) {
    return known_level->second;
  }

  const auto &cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  size_t max_level = 0;
  for (const auto &input : cnode->inputs()) {
    const size_t input_level = ParseControlNodeLevel(input, checked_nodes, node_to_level);
    if (input_level > max_level) {
      max_level = input_level;
    }
  }
  return max_level;
}
} // namespace
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index) {
@ -755,7 +728,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
ParseUnRecursionCallNode();
ParseNeedStackKernelGraph(kernel_graph_to_device_contexts);
ParseKernelGraphGroup(kernel_graph_to_device_contexts);
ParseNodeLevel(control_nodes);
@ -841,14 +814,21 @@ bool ControlNodeParser::IsRootGraphPersistentDeviceTensor(const AnfNodePtr &node
return find(root_graph_parameters_.begin(), root_graph_parameters_.end(), node) != root_graph_parameters_.end();
}
// Tell whether the node is registered in need_stack_control_nodes_. Only a CNode can qualify; any other kind of
// node yields false without consulting the set.
bool ControlNodeParser::IsNeedStackControlNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const bool is_cnode = node->isa<CNode>();
  return is_cnode && (need_stack_control_nodes_.find(node) != need_stack_control_nodes_.end());
}
// Report whether the node is treated as a recursive call node.
// A node that AnfAlgo::IsCallNode rejects always yields false.
bool ControlNodeParser::IsRecursionCallNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsCallNode(node)) {
return false;
}
// NOTE(review): two return statements follow back to back, so as written only the first one can execute and the
// second is unreachable. This looks like the old and new versions of the same statement kept side by side -
// confirm which one is meant to be live.
// First form: true when the node is absent from unrecursion_call_nodes_ or present in need_stack_control_nodes_.
return (find(unrecursion_call_nodes_.begin(), unrecursion_call_nodes_.end(), node) ==
unrecursion_call_nodes_.end()) ||
(need_stack_control_nodes_.find(node) != need_stack_control_nodes_.end());
// Second form: true only when the node is absent from unrecursion_call_nodes_.
return find(unrecursion_call_nodes_.begin(), unrecursion_call_nodes_.end(), node) == unrecursion_call_nodes_.end();
}
bool ControlNodeParser::IsRecursionKernelGraph(const KernelGraphPtr &graph) {
@ -1852,6 +1832,16 @@ void CollectEffectiveOutputByGraph(const KernelGraphPtr &graph, DeviceContext *c
backend_to_front.second.first->isa<ValueNode>()) {
continue;
}
// Skip the function input.
const auto &abstract = backend_to_front.second.first->abstract();
MS_EXCEPTION_IF_NULL(abstract);
const auto &real_abstract = FetchAbstractByIndex(abstract, backend_to_front.second.second);
MS_EXCEPTION_IF_NULL(real_abstract);
if (real_abstract->isa<abstract::AbstractFunction>()) {
continue;
}
MS_LOG(DEBUG) << "Kernel graph:" << graph->ToString()
<< " add front output node:" << backend_to_front.second.first->DebugString()
<< " index:" << backend_to_front.second.second
@ -1861,7 +1851,7 @@ void CollectEffectiveOutputByGraph(const KernelGraphPtr &graph, DeviceContext *c
}
}
void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts) {
void ControlNodeParser::ParseKernelGraphGroup(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts) {
for (const auto &func_graph_to_kernel_graph_groups : func_graph_to_kernel_graph_groups_) {
for (const auto &kernel_graph_group : func_graph_to_kernel_graph_groups.second) {
if (kernel_graph_group.empty()) {
@ -1903,6 +1893,49 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte
}
}
// Get the level of a node by recursively traversing what feeds it and returning the largest level found in
// node_to_level_.
// - node: the front node to inspect; anything that is not a CNode contributes level 0.
// - checked_nodes: nodes already visited in this traversal; a node seen before contributes level 0, which also
//   keeps the recursion from revisiting the same node.
// Returns the recorded level of the node if it has one, otherwise the maximum level among its inputs (or among the
// front inputs of its kernel graph group when the node belongs to a kernel graph).
size_t ControlNodeParser::ParseControlNodeLevel(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(checked_nodes);
  if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) {
    return 0;
  }
  (void)checked_nodes->emplace(node);

  auto iter = node_to_level_.find(node);
  if (iter != node_to_level_.end()) {
    return iter->second;
  }

  size_t level = 0;
  const auto &kernel_graph = FetchKernelGraphByFrontNode(node);
  if (kernel_graph == nullptr) {
    // If the kernel graph is not found, it means that the input does not come from the kernel graph, then
    // just continue to traverse the input.
    const auto &cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    const auto &inputs = cnode->inputs();
    for (const auto &input : inputs) {
      size_t tmp_level = ParseControlNodeLevel(input, checked_nodes);
      level = (tmp_level > level ? tmp_level : level);
    }
    return level;
  }

  // If the input comes from the kernel graph, you need to check all the graph's input, not just the node's input.
  auto group_info_iter = kernel_graphs_to_group_info_.find(kernel_graph);
  if (group_info_iter == kernel_graphs_to_group_info_.end()) {
    MS_LOG(EXCEPTION) << "Failed to get kernel graph group info for graph:" << kernel_graph->ToString();
  }
  // The group info is dereferenced right below, so reject a null entry explicitly instead of crashing on it.
  MS_EXCEPTION_IF_NULL(group_info_iter->second);
  const auto &inputs = group_info_iter->second->front_input_nodes_;
  for (const auto &input : inputs) {
    // Use a distinct name here so the function parameter `node` is not shadowed inside the loop.
    const auto &input_node = input.first.first;
    size_t tmp_level = ParseControlNodeLevel(input_node, checked_nodes);
    level = (tmp_level > level ? tmp_level : level);
  }
  return level;
}
void ControlNodeParser::ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes) {
size_t level = 0;
// 1. Parse levels of control nodes.
@ -1917,12 +1950,12 @@ void ControlNodeParser::ParseNodeLevel(const std::vector<AnfNodePtr> &control_no
node_to_level_[parameter] = level;
}
continue;
} else if (AnfAlgo::IsCallNode(control_node) && IsRecursionCallNode(control_node)) {
} else if (IsRecursionCallNode(control_node)) {
++level;
node_to_level_[control_node] = level;
} else {
std::set<AnfNodePtr> checked_nodes;
node_to_level_[control_node] = ParseControlNodeLevel(control_node, &checked_nodes, node_to_level_);
node_to_level_[control_node] = ParseControlNodeLevel(control_node, &checked_nodes);
}
}

View File

@ -134,6 +134,7 @@ class ControlNodeParser {
// persistent and real parameters passed.
bool IsRootGraphPersistentDeviceTensor(const AnfNodePtr &node);
bool IsRecursionCallNode(const AnfNodePtr &node);
// Whether the node is a CNode recorded in need_stack_control_nodes_.
bool IsNeedStackControlNode(const AnfNodePtr &node);
// If there is a recursive call node in the input of the kernel graph, the graph is recursive.
bool IsRecursionKernelGraph(const KernelGraphPtr &graph);
bool IsSameKernelGraphGroup(const AnfNodePtr &node, const KernelGraphPtr &graph);
@ -227,9 +228,12 @@ class ControlNodeParser {
// When a control node or kernel graph has input that is a call node, you need to add a stack actor for it.
void ParseNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes);
bool IsCallNodeNeedStack(const AnfNodePtr &node);
void ParseNeedStackKernelGraph(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts);
void ParseKernelGraphGroup(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts);
// Parse the level of inputs and outputs of graphs and all control nodes.
void ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes);
// Get the level of the control node, recursively traverse all the inputs of the node, and find the largest level
// among them.
size_t ParseControlNodeLevel(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes);
// When the parameter is directly used as the condition of the switch, there will be no back-end node, and a device
// tensor needs to be created for it.
void CreateDeviceTensorForRootGraphParameter(DeviceContext *const default_context);

View File

@ -525,7 +525,7 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
const auto &parser = graph_compiler_info.control_node_parser_;
for (auto &switch_actor : control_actor_set->switch_actors_) {
MS_EXCEPTION_IF_NULL(switch_actor);
if (parser->need_stack_control_nodes_.find(switch_actor->node_) == parser->need_stack_control_nodes_.end()) {
if (!parser->IsNeedStackControlNode(switch_actor->node_)) {
for (size_t i = 0; i < switch_actor->formal_parameters_.size(); ++i) {
LinkArrowbyFormalParameter(switch_actor.get(), switch_actor->formal_parameters_[i], {switch_actor->node_, i},
parser);
@ -537,13 +537,13 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
MS_EXCEPTION_IF_NULL(actor);
auto stack_actor = dynamic_cast<StackActor *>(actor);
MS_EXCEPTION_IF_NULL(stack_actor);
LinkArrowFromStackActor(stack_actor, switch_actor.get());
LinkArrowFromStackActor(stack_actor, switch_actor.get(), parser);
}
}
for (auto &gather_actor : control_actor_set->gather_actors_) {
MS_EXCEPTION_IF_NULL(gather_actor->node_);
if (parser->need_stack_control_nodes_.find(gather_actor->node_) == parser->need_stack_control_nodes_.end()) {
if (!parser->IsNeedStackControlNode(gather_actor->node_)) {
for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) {
LinkArrowbyFormalParameter(gather_actor.get(), gather_actor->formal_parameters_[i], {gather_actor->node_, i},
parser);
@ -555,7 +555,7 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
MS_EXCEPTION_IF_NULL(actor);
auto stack_actor = dynamic_cast<StackActor *>(actor);
MS_EXCEPTION_IF_NULL(stack_actor);
LinkArrowFromStackActor(stack_actor, gather_actor.get());
LinkArrowFromStackActor(stack_actor, gather_actor.get(), parser);
}
}
@ -567,20 +567,19 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
for (auto &exit_actor : control_actor_set->exit_actors_) {
MS_EXCEPTION_IF_NULL(exit_actor);
if (exit_actor->node_ == nullptr ||
(parser->need_stack_control_nodes_.find(exit_actor->node_) == parser->need_stack_control_nodes_.end())) {
auto stack_actor_name = (exit_actor->node_ == nullptr ? GetStackActorNameByExitName(exit_actor->GetAID().Name())
: GetActorName(exit_actor->node_) + kStackActorNameSuffix);
auto actor = FetchActor(stack_actor_name);
if (actor == nullptr) {
for (size_t i = 0; i < exit_actor->formal_parameters_.size(); ++i) {
LinkArrowbyFormalParameter(exit_actor.get(), exit_actor->formal_parameters_[i], {exit_actor->node_, i}, parser);
}
} else {
// If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
auto stack_actor_name = (exit_actor->node_ == nullptr ? GetStackActorNameByExitName(exit_actor->GetAID().Name())
: GetActorName(exit_actor->node_) + kStackActorNameSuffix);
auto actor = FetchActor(stack_actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto stack_actor = dynamic_cast<StackActor *>(actor);
MS_EXCEPTION_IF_NULL(stack_actor);
LinkArrowFromStackActor(stack_actor, exit_actor.get());
LinkArrowFromStackActor(stack_actor, exit_actor.get(), parser);
}
}
@ -592,7 +591,8 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
}
}
void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor, ControlActor *const to_actor) {
void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor, ControlActor *const to_actor,
const ControlNodeParserPtr &parser) {
MS_EXCEPTION_IF_NULL(stack_actor);
MS_EXCEPTION_IF_NULL(to_actor);
@ -605,6 +605,14 @@ void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor
}
// Fetch the arrow type of input.
if (to_actor->type_ == KernelTransformType::kExitActor && to_actor->node_ == nullptr && from_node->isa<CNode>() &&
(!AnfAlgo::IsCallNode(from_node)) && (!AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) &&
to_actor->GetAID().Name().find(
parser->FetchGroupNameByKernelGraph(parser->FetchKernelGraphByFrontNode(from_node))) != std::string::npos) {
LinkArrowByKernel(from_node, to_actor, formal_parameter, {to_actor->node_, to_index}, parser);
continue;
}
size_t from_index = stack_actor->FetchNodePosition(formal_parameter);
const auto &abstract = formal_parameter.first->abstract();
MS_EXCEPTION_IF_NULL(abstract);
@ -876,7 +884,7 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
}
auto from_actor = control_actor;
if (parser->need_stack_control_nodes_.find(node) != parser->need_stack_control_nodes_.end()) {
if (parser->IsNeedStackControlNode(node)) {
const auto &stack_actor_name = GetActorName(node) + kStackActorNameSuffix;
auto actor = FetchActor(stack_actor_name);
MS_EXCEPTION_IF_NULL(actor);
@ -1116,8 +1124,8 @@ void ControlNodeScheduler::LinkControlArrowByAutoMonad(ControlActor *to_actor, c
(void)from_actors.emplace_back(from_actor);
LinkControlArrow(from_actor, to_actor);
}
if (to_actor->type_ != KernelTransformType::kStackActor || parser->IsRecursionCallNode(depend_node) ||
(graph != nullptr && parser->IsRecursionKernelGraph(graph))) {
if (to_actor->type_ != KernelTransformType::kStackActor || parser->IsNeedStackControlNode(depend_node) ||
parser->IsRecursionCallNode(depend_node) || (graph != nullptr && parser->IsRecursionKernelGraph(graph))) {
continue;
}
// If the control arrow comes from a recursive call node or a recursive kernel graph, these control edges will be

View File

@ -76,7 +76,7 @@ class ControlNodeScheduler {
void LinkArrowByValueNode(const AnfNodePtr &value_node, ControlActor *const to_actor, size_t from_index,
size_t to_index);
// Link arrow from stack actor to control actor.
void LinkArrowFromStackActor(StackActor *stack_actor, ControlActor *to_actor);
void LinkArrowFromStackActor(StackActor *stack_actor, ControlActor *to_actor, const ControlNodeParserPtr &parser);
// Link data arrow between control actor and actor in frame, including kernel actor, output actor, data source actor.
void LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info);

View File

@ -60,7 +60,6 @@ class BackwardNet(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard