From 78fc2e45ffaebe9545ff5d74cc5ff4223f58bd76 Mon Sep 17 00:00:00 2001
From: Margaret_wangrui
Date: Thu, 14 Oct 2021 09:18:39 +0800
Subject: [PATCH] Fix the resolve problem with parameters with the same name

---
 mindspore/ccsrc/pipeline/jit/parse/parse.cc   |  1 +
 mindspore/ccsrc/pipeline/jit/parse/resolve.cc |  2 +-
 mindspore/core/ir/anf.h                       |  4 ++
 .../simple_expression/test_hyper_param.py     | 47 +++++++++++++++++++
 4 files changed, 53 insertions(+), 1 deletion(-)
 create mode 100644 tests/syntax/simple_expression/test_hyper_param.py

diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc
index 7cc3f4a6847..0bc06d4a606 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc
@@ -2167,6 +2167,7 @@ FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
     param->set_name(name);
     param->debug_info()->set_name(name);
     param->debug_info()->set_location(param->debug_info()->location());
+    param->set_is_top_graph_param(true);
   }
   func_graph->set_has_vararg(current_graph->has_vararg());
   func_graph->set_has_kwarg(current_graph->has_kwarg());
diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
index 46b08a7d4fd..83f212f9115 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
@@ -113,7 +113,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
   AnfNodePtr para_node = nullptr;
   for (auto const &param : top_func_graph->parameters()) {
     auto param_node = dyn_cast<Parameter>(param);
-    if (param_node != nullptr && param_node->name() == param_name) {
+    if (param_node != nullptr && param_node->name() == param_name && !param_node->is_top_graph_param()) {
       para_node = param;
       MS_LOG(DEBUG) << "Found existing parameter for " << func_graph->ToString()
                     << ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h
index 3f019575457..40e0b4a2957 100644
--- a/mindspore/core/ir/anf.h
+++ b/mindspore/core/ir/anf.h
@@ -403,6 +403,9 @@ class MS_CORE_API Parameter : public ANode {
   void DecreaseUsedGraphCount() { used_graph_count_--; }
   int used_graph_count() const { return used_graph_count_; }
+  bool is_top_graph_param() const { return is_top_graph_param_; }
+  void set_is_top_graph_param(bool flag) { is_top_graph_param_ = flag; }
+
   bool operator==(const AnfNode &other) const override {
     if (!other.isa<Parameter>()) {
       return false;
     }
@@ -439,6 +442,7 @@
   int used_graph_count_;
   // groups attr in FracZ format
   int64_t fracz_group_ = 1;
+  bool is_top_graph_param_ = false;
 };

 using ParameterPtr = std::shared_ptr<Parameter>;
diff --git a/tests/syntax/simple_expression/test_hyper_param.py b/tests/syntax/simple_expression/test_hyper_param.py
new file mode 100644
index 00000000000..e6049fec531
--- /dev/null
+++ b/tests/syntax/simple_expression/test_hyper_param.py
@@ -0,0 +1,47 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import mindspore as ms
+from mindspore import Tensor, Parameter
+from mindspore.nn import Cell
+
+
+def test_hyper_param():
+    """
+    Feature: Resolve parameter.
+    Description: The parameter name used in construct is the same as the name of a Parameter defined in __init__.
+    Expectation: self.a is a different parameter from the construct argument a.
+    """
+    class HyperParamNet(Cell):
+        def __init__(self):
+            super(HyperParamNet, self).__init__()
+            self.a = Parameter(Tensor(1, ms.float32), name="a")
+            self.b = Parameter(Tensor(5, ms.float32), name="param_b")
+            self.c = Parameter(Tensor(9, ms.float32), name="param_c")
+
+        def func_inner(self, c):
+            return self.a + self.b + c
+
+        def construct(self, a, b):
+            self.a = a
+            self.b = b
+            return self.func_inner(self.c)
+
+    x = Tensor(11, ms.float32)
+    y = Tensor(19, ms.float32)
+    net = HyperParamNet()
+    output = net(x, y)
+    output_expect = Tensor(39, ms.float32)
+    assert output == output_expect
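
As a minimal illustration of the scenario this change targets (a sketch, not part of the change set): the class name SameNameNet, its values, and the expected result below are assumptions for illustration and do not come from the MindSpore test suite. The weight self.a and the construct argument deliberately share the name "a"; with the new is_top_graph_param flag, resolving self.a should bind to the weight rather than to the top graph's input parameter of the same name.

# Illustrative sketch only (not part of this patch). Assumes graph mode so the
# ResolveParameterObj path touched by this change is exercised.
import mindspore as ms
from mindspore import Tensor, Parameter, context
from mindspore.nn import Cell

context.set_context(mode=context.GRAPH_MODE)


class SameNameNet(Cell):
    def __init__(self):
        super(SameNameNet, self).__init__()
        # The registered parameter name collides with the construct argument name.
        self.a = Parameter(Tensor(2, ms.float32), name="a")

    def construct(self, a):
        # self.a should resolve to the weight above (2.0), not to the input "a".
        return self.a * a


net = SameNameNet()
output = net(Tensor(3, ms.float32))
# Expected if self.a resolves to the weight: 2 * 3 = 6.
assert output == Tensor(6, ms.float32)

Without the is_top_graph_param check in ResolveParameterObj, the lookup by name could match the top graph's input parameter "a" instead of adding a node for the weight, which is the failure mode the new test_hyper_param case covers.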