forked from mindspore-Ecosystem/mindspore
fix the bug of actor runtime auto monad
This commit is contained in:
parent
c986916d48
commit
616bc83acf
|
@ -711,6 +711,8 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
|
||||||
GraphExecutionStrategy strategy) {
|
GraphExecutionStrategy strategy) {
|
||||||
MS_EXCEPTION_IF_NULL(actor_set);
|
MS_EXCEPTION_IF_NULL(actor_set);
|
||||||
std::vector<KernelActor *> auto_monad_actors;
|
std::vector<KernelActor *> auto_monad_actors;
|
||||||
|
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
|
||||||
|
prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
|
||||||
|
|
||||||
// Foreach the execution order to link the actors.
|
// Foreach the execution order to link the actors.
|
||||||
for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
|
for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
|
||||||
|
@ -727,7 +729,9 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
|
||||||
for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
|
for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
|
||||||
auto input_node = AnfAlgo::GetInputNode(kernel, i);
|
auto input_node = AnfAlgo::GetInputNode(kernel, i);
|
||||||
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
|
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
|
||||||
LinkControlArrowByAutoMonad(kernel_actor, input_node);
|
if (AnfAlgo::IsOneOfPrimitiveCNode(input_node, auto_monad_prims)) {
|
||||||
|
LinkControlArrowByAutoMonad(kernel_actor, input_node);
|
||||||
|
}
|
||||||
if (HasAbstractMonad(input_node)) {
|
if (HasAbstractMonad(input_node)) {
|
||||||
auto_monad_actors.emplace_back(kernel_actor);
|
auto_monad_actors.emplace_back(kernel_actor);
|
||||||
continue; // No data arrow for monad input.
|
continue; // No data arrow for monad input.
|
||||||
|
@ -1360,6 +1364,8 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An
|
||||||
}
|
}
|
||||||
} else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimLoad)) {
|
} else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimLoad)) {
|
||||||
real_depend_inputs.push_back(input_cnode->input(kLoadStateInput));
|
real_depend_inputs.push_back(input_cnode->input(kLoadStateInput));
|
||||||
|
} else {
|
||||||
|
real_depend_inputs.push_back(input_cnode);
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {
|
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {
|
||||||
|
|
Loading…
Reference in New Issue