fix the bug of actor runtime auto monad

This commit is contained in:
limingqi107 2021-06-25 20:23:17 +08:00
parent c986916d48
commit 616bc83acf
1 changed files with 7 additions and 1 deletions

View File

@ -711,6 +711,8 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
GraphExecutionStrategy strategy) { GraphExecutionStrategy strategy) {
MS_EXCEPTION_IF_NULL(actor_set); MS_EXCEPTION_IF_NULL(actor_set);
std::vector<KernelActor *> auto_monad_actors; std::vector<KernelActor *> auto_monad_actors;
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
// Foreach the execution order to link the actors. // Foreach the execution order to link the actors.
for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) { for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
@ -727,7 +729,9 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) { for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
auto input_node = AnfAlgo::GetInputNode(kernel, i); auto input_node = AnfAlgo::GetInputNode(kernel, i);
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node. // Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
LinkControlArrowByAutoMonad(kernel_actor, input_node); if (AnfAlgo::IsOneOfPrimitiveCNode(input_node, auto_monad_prims)) {
LinkControlArrowByAutoMonad(kernel_actor, input_node);
}
if (HasAbstractMonad(input_node)) { if (HasAbstractMonad(input_node)) {
auto_monad_actors.emplace_back(kernel_actor); auto_monad_actors.emplace_back(kernel_actor);
continue; // No data arrow for monad input. continue; // No data arrow for monad input.
@ -1360,6 +1364,8 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An
} }
} else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimLoad)) { } else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimLoad)) {
real_depend_inputs.push_back(input_cnode->input(kLoadStateInput)); real_depend_inputs.push_back(input_cnode->input(kLoadStateInput));
} else {
real_depend_inputs.push_back(input_cnode);
} }
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = { const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {