From feb1c36811ce59a41f8cabf359468f76ad21fbe2 Mon Sep 17 00:00:00 2001 From: panyifeng Date: Tue, 31 Mar 2020 14:55:29 +0800 Subject: [PATCH] fix parallel related valuenode merging error --- mindspore/ccsrc/parallel/step_parallel.cc | 1 + mindspore/ccsrc/pipeline/action.cc | 28 +++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index 927acea7055..65e5cb976a8 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -112,6 +112,7 @@ void InsertNode(const Operator& op, const CNodePtr& node, size_t index, const An MS_EXCEPTION_IF_NULL(new_node_value); PrimitivePtr new_node_prim = new_node_value->value()->cast(); new_node_prim->set_instance_name(instance_name); + new_node_prim->set_attr("keep_value_node_input", MakeValue(true)); new_node->set_scope(scope); node_input[0]->set_scope(scope); manager->SetEdge(node, SizeToInt(index), new_node); diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index f3742ab6541..392602f419b 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -276,6 +276,31 @@ bool ExecuteAction(const ResourcePtr& res) { return true; } +// The parallel primitive related valuenode might be partitioned so that its value changes by device, +// that will result in a syncronization error due to different executing order. +// Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive, +// the final solution will be proposed later as a parallel feature. +bool KeepValueNodeDuplication(const AnfNodePtr& value_node, const ResourcePtr& res) { + auto& node_users = res->manager()->node_users(); + auto& users = node_users[value_node]; + auto used_by_keep_value_prim = + std::any_of(users.begin(), users.end(), [](const std::pair& user) -> bool { + MS_EXCEPTION_IF_NULL(user.first); + auto cnode = user.first->cast(); + if (cnode == nullptr) { + return false; + } + auto prim_node = cnode->input(0); + if (IsValueNode(prim_node)) { + auto prim = GetValue(prim_node->cast()->value()); + // value_node is referenced by some parallel primitive + return prim->HasAttr("keep_value_node_input"); + } + return false; + }); + return used_by_keep_value_prim; +} + bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) { if (res->func_graph() == nullptr) { MS_LOG(EXCEPTION) << "Remove value node duplications error."; @@ -287,6 +312,9 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) { HashCache hash_cache; HashValue hashes; for (const auto& value_pair : value_nodes) { + if (KeepValueNodeDuplication(value_pair.first, res)) { + continue; + } TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes); } return true;