set cnode's fullname when cloning

This commit is contained in:
Xiaoda Zhang 2020-12-04 15:59:44 +08:00
parent c81bb8fe39
commit 9a9e3a751e
7 changed files with 41 additions and 24 deletions

View File

@ -95,6 +95,9 @@ class TwoReshapeEliminater : public AnfVisitor {
if (fg != nullptr && x_ != nullptr && shape_ != nullptr) {
auto new_node = fg->NewCNode({NewValueNode(prim_), x_, shape_});
new_node->set_abstract(node->abstract());
if (node->scope() != kDefaultScope) {
new_node->set_scope(node->scope());
}
new_node->set_fullname_with_scope(node->fullname_with_scope());
return new_node;
}

View File

@ -689,6 +689,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
cnode->set_user_data<OperatorInfo>(current_op_ptr);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
<< " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
}
}

View File

@ -91,6 +91,9 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
new_node->set_inputs_value(old_node->inputs_value());
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_node->set_scope(scope);
if (IsPrimitiveCNode(old_node, nullptr) && new_node->scope() == kDefaultScope) {
new_node->set_fullname_with_scope(old_node->fullname_with_scope());
}
new_node->set_kernel_info(old_node->kernel_info_ptr());
repl_node_[old_node] = new_node;
nodes_.emplace_back(old_node, new_node);

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import numpy as np
import mindspore as ms
@ -78,9 +79,11 @@ def test_auto_parallel_arithmetic():
b = Tensor(np.ones([64, 128]), dtype=ms.float32)
compile_net(net, x, y, b, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[2, 4], [2, 4]],
'Default/network-Net/MatMul-op1': [[2, 1], [1, 4]]}
assert strategies == expected_strategies
for (k, v) in strategies.items():
if re.search('FloorDiv-op', k) is not None:
assert v == [[2, 4], [2, 4]]
elif re.search('MatMul-op', k) is not None:
assert v == [[2, 1], [1, 4]]
def test_auto_parallel_arithmetic_broadcast_both():
@ -105,9 +108,11 @@ def test_auto_parallel_arithmetic_broadcast_both():
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
compile_net(net, x, y, b, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[8, 1], [1, 1]],
'Default/network-Net/MatMul-op1': [[8, 1], [1, 1]]}
assert strategies == expected_strategies
for (k, v) in strategies.items():
if re.search('FloorDiv-op', k) is not None:
assert v == [[8, 1], [1, 1]]
elif re.search('MatMul-op', k) is not None:
assert v == [[8, 1], [1, 1]]
def test_auto_parallel_arithmetic_broadcast_right():
@ -132,9 +137,11 @@ def test_auto_parallel_arithmetic_broadcast_right():
b = Tensor(np.ones([32]), dtype=ms.float32)
compile_net(net, x, y, b, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[4, 2], [2]],
'Default/network-Net/MatMul-op1': [[4, 1], [1, 2]]}
assert strategies == expected_strategies
for (k, v) in strategies.items():
if re.search('FloorDiv-op', k) is not None:
assert v == [[4, 2], [2]]
elif re.search('MatMul-op', k) is not None:
assert v == [[4, 1], [1, 2]]
def test_auto_parallel_arithmetic_broadcast_left():
@ -159,6 +166,8 @@ def test_auto_parallel_arithmetic_broadcast_left():
b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
compile_net(net, x, y, b, phase="train")
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[4, 2], [1, 4, 2]],
'Default/network-Net/MatMul-op1': [[4, 1], [1, 2]]}
assert strategies == expected_strategies
for (k, v) in strategies.items():
if re.search('FloorDiv-op', k) is not None:
assert v == [[4, 2], [1, 4, 2]]
elif re.search('MatMul-op', k) is not None:
assert v == [[4, 1], [1, 2]]

View File

@ -84,9 +84,9 @@ def test_double_star_graph():
net.set_train()
_executor.compile(net, x, y, z, w, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/Cast-op1': [[8, 1]],
'Default/network-Net/Cast-op3': [[1, 8]],
'Default/network-Net/MatMul-op2': [[8, 1], [1, 1]],
'Default/network-Net/MatMul-op4': [[1, 1], [1, 8]],
'Default/network-Net/MatMul-op0': [[1, 8], [8, 1]]}
expected_strategies = {'Default/network-Net/Cast-op5': [[8, 1]],
'Default/network-Net/Cast-op7': [[1, 8]],
'Default/network-Net/MatMul-op6': [[8, 1], [1, 1]],
'Default/network-Net/MatMul-op8': [[1, 1], [1, 8]],
'Default/network-Net/MatMul-op4': [[1, 8], [8, 1]]}
assert strategies == expected_strategies

View File

@ -79,8 +79,8 @@ def test_two_matmul_transpose():
net.set_train()
_executor.compile(net, x, y, b, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/Transpose-op0': [[1, 16]],
'Default/network-Net/Transpose-op1': [[16, 1]],
'Default/network-Net/MatMul-op2': [[16, 1], [1, 1]],
'Default/network-Net/MatMul-op3': [[16, 1], [1, 1]]}
expected_strategies = {'Default/network-Net/Transpose-op4': [[1, 16]],
'Default/network-Net/Transpose-op5': [[16, 1]],
'Default/network-Net/MatMul-op7': [[16, 1], [1, 1]],
'Default/network-Net/MatMul-op6': [[16, 1], [1, 1]]}
assert strategies == expected_strategies

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import numpy as np
import mindspore as ms
@ -155,6 +156,6 @@ def test_two_matmul():
net.set_train()
_executor.compile(net, x, y, b, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/MatMul-op0': [[16, 1], [1, 1]],
'Default/network-Net/MatMul-op1': [[16, 1], [1, 1]]}
assert strategies == expected_strategies
for (k, v) in strategies.items():
if re.search('MatMul-op', k) is not None:
assert v == [[16, 1], [1, 1]]