From 1683946951c48696a6df11d51e7d9086f0427f03 Mon Sep 17 00:00:00 2001 From: chenfei Date: Fri, 19 Aug 2022 16:17:37 +0800 Subject: [PATCH] fix bug of ascend when 2 parameter's arg are same --- .../backend/common/session/kernel_graph.cc | 9 ++++- .../construct_input/test_outermost_input.py | 35 ++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/backend/common/session/kernel_graph.cc b/mindspore/ccsrc/backend/common/session/kernel_graph.cc index 5091ebb0a26..07bb3879c08 100644 --- a/mindspore/ccsrc/backend/common/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/common/session/kernel_graph.cc @@ -770,7 +770,15 @@ void KernelGraph::FrontBackendMapAdd(const AnfNodePtr &front_anf, const AnfNodeP if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) { MS_LOG(EXCEPTION) << "Anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_"; } + front_backend_anf_map_[front_anf] = backend_anf; if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) { + // If def func(x, y) and call as func(arg, arg) ,then the parameter x and y share same param_info "arg". + // In this case, parameter is get from param_info and has been exist in the map. So can't add it to map again. + if (backend_anf->isa()) { + MS_LOG(INFO) << "Backend parameter already exist, backend parameter:" << backend_anf->DebugString() + << ", exist front parameter:" << backend_front_anf_map_[backend_anf]->DebugString(); + return; + } auto front_node = front_anf->cast(); MS_EXCEPTION_IF_NULL(front_node); auto attr_input = front_node->input(kAnfPrimitiveIndex); @@ -779,7 +787,6 @@ void KernelGraph::FrontBackendMapAdd(const AnfNodePtr &front_anf, const AnfNodeP MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_"; } } - front_backend_anf_map_[front_anf] = backend_anf; backend_front_anf_map_[backend_anf] = front_anf; } diff --git a/tests/st/construct_input/test_outermost_input.py b/tests/st/construct_input/test_outermost_input.py index 208c090473d..dfc4393b0bc 100644 --- a/tests/st/construct_input/test_outermost_input.py +++ b/tests/st/construct_input/test_outermost_input.py @@ -120,6 +120,7 @@ def test_grad_first_input_net(mode): Description: Normal input type. Expectation: No exception. """ + class FirstInputTensorNet(nn.Cell): def construct(self, tensor_a, tuple_a, list_b, tensor_b, tensor_c, dict_c): return tensor_a + tuple_a[2] - list_b[1][1]["y"] + tensor_b - tensor_c + dict_c["y"] @@ -152,6 +153,17 @@ class GradCellWithParameter(nn.Cell): return self.grad(self.net, self.param)(x) +class AssignParameterWithCell(nn.Cell): + def __init__(self, net): + super().__init__() + self.net = net + self.param = self.net.param + + def construct(self, x): + self.param = self.param * 2 + return x + + class GradCell(nn.Cell): def __init__(self, net): super().__init__() @@ -215,9 +227,10 @@ def test_grad_parameter_as_input_and_fv(mode): # PyNative run error. # Support context.PYNATIVE_MODE later. -# Support Ascend BE later. @pytest.mark.level0 @pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @pytest.mark.parametrize('mode', [context.GRAPH_MODE]) @@ -238,6 +251,26 @@ def test_grad_same_parameter_both_input_and_fv(mode): assert np.array_equal(a[1].asnumpy(), b[1].asnumpy()) +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE]) +def test_same_arg_parameter_assign(mode): + """ + Feature: Change value of common parameter. + Description: Change value of common parameter.Watch another parameter with same param_obj. + Expectation: No exception. + """ + context.set_context(mode=mode) + x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x') + a = AssignParameterWithCell(TestCell(x))(x) + print(f'a: {a}') + assert np.array_equal(a.asnumpy(), x.asnumpy()) + + class TestCell2(nn.Cell): def __init__(self, param1, param2): super().__init__()