fix parallel related valuenode merging error
This commit is contained in:
parent
2753aa8768
commit
feb1c36811
|
@ -112,6 +112,7 @@ void InsertNode(const Operator& op, const CNodePtr& node, size_t index, const An
|
||||||
MS_EXCEPTION_IF_NULL(new_node_value);
|
MS_EXCEPTION_IF_NULL(new_node_value);
|
||||||
PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
|
PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
|
||||||
new_node_prim->set_instance_name(instance_name);
|
new_node_prim->set_instance_name(instance_name);
|
||||||
|
new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
|
||||||
new_node->set_scope(scope);
|
new_node->set_scope(scope);
|
||||||
node_input[0]->set_scope(scope);
|
node_input[0]->set_scope(scope);
|
||||||
manager->SetEdge(node, SizeToInt(index), new_node);
|
manager->SetEdge(node, SizeToInt(index), new_node);
|
||||||
|
|
|
@ -276,6 +276,31 @@ bool ExecuteAction(const ResourcePtr& res) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The parallel primitive related valuenode might be partitioned so that its value changes by device,
|
||||||
|
// that will result in a syncronization error due to different executing order.
|
||||||
|
// Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive,
|
||||||
|
// the final solution will be proposed later as a parallel feature.
|
||||||
|
bool KeepValueNodeDuplication(const AnfNodePtr& value_node, const ResourcePtr& res) {
|
||||||
|
auto& node_users = res->manager()->node_users();
|
||||||
|
auto& users = node_users[value_node];
|
||||||
|
auto used_by_keep_value_prim =
|
||||||
|
std::any_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int>& user) -> bool {
|
||||||
|
MS_EXCEPTION_IF_NULL(user.first);
|
||||||
|
auto cnode = user.first->cast<CNodePtr>();
|
||||||
|
if (cnode == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto prim_node = cnode->input(0);
|
||||||
|
if (IsValueNode<Primitive>(prim_node)) {
|
||||||
|
auto prim = GetValue<PrimitivePtr>(prim_node->cast<ValueNodePtr>()->value());
|
||||||
|
// value_node is referenced by some parallel primitive
|
||||||
|
return prim->HasAttr("keep_value_node_input");
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
return used_by_keep_value_prim;
|
||||||
|
}
|
||||||
|
|
||||||
bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) {
|
bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) {
|
||||||
if (res->func_graph() == nullptr) {
|
if (res->func_graph() == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Remove value node duplications error.";
|
MS_LOG(EXCEPTION) << "Remove value node duplications error.";
|
||||||
|
@ -287,6 +312,9 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) {
|
||||||
HashCache hash_cache;
|
HashCache hash_cache;
|
||||||
HashValue hashes;
|
HashValue hashes;
|
||||||
for (const auto& value_pair : value_nodes) {
|
for (const auto& value_pair : value_nodes) {
|
||||||
|
if (KeepValueNodeDuplication(value_pair.first, res)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
|
TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
Loading…
Reference in New Issue