fix the scope setting error when cloning nodes

This commit is contained in:
Xiaoda Zhang 2021-09-06 19:36:36 +08:00
parent 0930a1617c
commit 352f40d750
5 changed files with 15 additions and 19 deletions

View File

@ -98,7 +98,6 @@ class TwoReshapeEliminater : public AnfVisitor {
if (node->scope() != kDefaultScope) {
new_node->set_scope(node->scope());
}
new_node->set_fullname_with_scope(node->fullname_with_scope());
return new_node;
}
return nullptr;

View File

@ -77,7 +77,7 @@ void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target,
// Default parameter can be shared since it is readonly.
new_param->set_default_param(old_param->default_param());
}
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
new_param->set_scope(scope);
repl_node_[node] = new_param;
}
@ -89,11 +89,8 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
CNodePtr new_node = std::make_shared<CNode>(AnfNodePtrList{}, target);
auto old_node = node->cast<CNodePtr>();
new_node->CloneCNodeInfo(old_node);
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
new_node->set_scope(scope);
if (IsParallelConsiderCNode(old_node) && new_node->scope() == kDefaultScope) {
new_node->set_fullname_with_scope(old_node->fullname_with_scope());
}
repl_node_[old_node] = new_node;
nodes_.emplace_back(old_node, new_node);
}
@ -102,7 +99,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
TraceGuard trace_guard(node->debug_info(), relation_);
ValueNodePtr new_const = NewValueNode(GetValueNode(node));
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
new_const->set_scope(scope);
new_const->set_abstract(node->abstract());
new_const->set_has_new_value(node->cast<ValueNodePtr>()->has_new_value());
@ -114,7 +111,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target)
MS_EXCEPTION_IF_NULL(target);
TraceGuard trace_guard(node->debug_info(), relation_);
ValueNodePtr new_const = NewValueNode(target);
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
new_const->set_scope(scope);
new_const->set_abstract(node->abstract());
new_const->set_has_new_value(node->cast<ValueNodePtr>()->has_new_value());

View File

@ -85,13 +85,13 @@ def run_e2e_dump():
add(Tensor(x), Tensor(y))
if context.get_context("device_target") == "Ascend":
assert len(os.listdir(dump_file_path)) == 5
output_name = "Add.Add-op1.0.0.*.output.0.DefaultFormat.npy"
output_name = "Add.Add-op*.0.0.*.output.0.DefaultFormat.npy"
elif context.get_context("device_target") == "CPU":
assert len(os.listdir(dump_file_path)) == 5
output_name = "Add.Add-op3.0.0.*.output.0.DefaultFormat.npy"
output_name = "Add.Add-op*.0.0.*.output.0.DefaultFormat.npy"
else:
assert len(os.listdir(dump_file_path)) == 3
output_name = "Add.Add-op3.0.0.*.output.0.DefaultFormat.npy"
output_name = "Add.Add-op*.0.0.*.output.0.DefaultFormat.npy"
output_path = glob.glob(os.path.join(dump_file_path, output_name))[0]
real_path = os.path.realpath(output_path)
output = np.load(real_path)

View File

@ -84,9 +84,9 @@ def test_double_star_graph():
net.set_train()
_executor.compile(net, x, y, z, w, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/Cast-op2': [[8, 1]],
'Default/network-Net/Cast-op4': [[1, 8]],
'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]],
'Default/network-Net/MatMul-op5': [[1, 1], [1, 8]],
'Default/network-Net/MatMul-op1': [[1, 8], [8, 1]]}
expected_strategies = {'Default/network-Net/Cast-op1': [[8, 1]],
'Default/network-Net/Cast-op3': [[1, 8]],
'Default/network-Net/MatMul-op2': [[8, 1], [1, 1]],
'Default/network-Net/MatMul-op4': [[1, 1], [1, 8]],
'Default/network-Net/MatMul-op0': [[1, 8], [8, 1]]}
assert strategies == expected_strategies

View File

@ -79,8 +79,8 @@ def test_two_matmul_transpose():
net.set_train()
_executor.compile(net, x, y, b, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/Transpose-op1': [[1, 16]],
'Default/network-Net/Transpose-op2': [[16, 1]],
expected_strategies = {'Default/network-Net/Transpose-op0': [[1, 16]],
'Default/network-Net/Transpose-op1': [[16, 1]],
'Default/network-Net/MatMul-op3': [[16, 1], [1, 1]],
'Default/network-Net/MatMul-op4': [[16, 1], [1, 1]]}
'Default/network-Net/MatMul-op2': [[16, 1], [1, 1]]}
assert strategies == expected_strategies