1. Graph scheduler only links monad control arrows within the same kernel graph.

2. Fix ref count for local tensor.
This commit is contained in:
gaoyong10 2021-12-14 11:53:06 +08:00
parent 944bdacd92
commit 7c061392e2
4 changed files with 17 additions and 6 deletions

View File

@ -1090,7 +1090,7 @@ void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail。"
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set output index:" << output_idx << " fail。"
<< " trace: " << trace::DumpSourceLines(node);
}
}

View File

@ -15,6 +15,7 @@
*/
#include "runtime/framework/control_node_parser.h"
#include "runtime/framework/actor/actor_common.h"
#include "abstract/utils.h"
#include "ir/tensor.h"
@ -411,6 +412,7 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index
MS_EXCEPTION_IF_NULL(address);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address;
AnfAlgo::SetOutputAddr(address, front_node_with_index.second, front_node.get());
UpdateRefCount(address.get(), true);
}
// Create a device tensor for front node.
@ -437,6 +439,7 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index
MS_EXCEPTION_IF_NULL(address);
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(node) << " addr:" << address;
AnfAlgo::SetOutputAddr(address, front_node_with_index.second, node.get());
UpdateRefCount(address.get(), true);
}
// Fetch all funcgraph by a seed graph, if a calls b, b calls c, and c calls a, return a set of a, b, c.

View File

@ -944,7 +944,7 @@ void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph,
auto input_node = AnfAlgo::GetInputNode(kernel, i);
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
if (IsOneOfPrimitiveCNode(input_node, auto_monad_prims) || HasAbstractMonad(input_node)) {
LinkControlArrowByAutoMonad(kernel_actor, input_node, graph);
LinkControlArrowByAutoMonad(kernel_actor, input_node, graph, graph_compiler_info.control_node_parser_);
}
if (HasAbstractMonad(input_node)) {
(void)auto_monad_actors->emplace_back(kernel_actor);
@ -1181,7 +1181,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor,
}
void GraphScheduler::LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node,
const KernelGraphPtr &graph) {
const KernelGraphPtr &graph, const ControlNodeParserPtr &parser) {
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(from_node);
MS_EXCEPTION_IF_NULL(graph);
@ -1199,7 +1199,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(AbstractActor *to_actor, const
if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimMakeTuple)) {
MS_EXCEPTION_IF_NULL(input_cnode);
for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph);
LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph, parser);
}
return;
}
@ -1240,6 +1240,13 @@ void GraphScheduler::LinkControlArrowByAutoMonad(AbstractActor *to_actor, const
}
MS_LOG(EXCEPTION) << "Can't find graph output by front node:" << front_output_with_index.first->DebugString();
}
if (parser != nullptr && parser->IsInited() &&
(!parser->IsSameKernelGraphGroup(front_output_with_index.first, graph))) {
MS_LOG(DEBUG) << "Skip in control flow from node:" << front_output_with_index.first->DebugString()
<< " is not in the graph:" << graph->ToString();
continue;
}
real_depend_kernel = graph_output_to_actor_[front_output_with_index].second.first;
MS_EXCEPTION_IF_NULL(real_depend_kernel);
MS_LOG(INFO) << "The graph " << graph->graph_id() << " link control arrow by auto monad from internal parameter: "
@ -1257,7 +1264,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(AbstractActor *to_actor, const
// The monad node and make tuple node need recursion.
if (IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) {
LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph);
LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph, parser);
continue;
}

View File

@ -141,7 +141,8 @@ class GraphScheduler {
const KernelWithIndex &to_kernel_with_input_idx);
// 2. The processing of linking control arrows.
void LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph);
void LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph,
const ControlNodeParserPtr &parser = nullptr);
// The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node.
void LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node);
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.