forked from mindspore-Ecosystem/mindspore
[ME] Add parameter name check.
This commit is contained in:
parent
33bc6978a0
commit
462c38813d
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -74,6 +74,13 @@ abstract::AbstractBasePtr ClassType::ToAbstract() {
|
|||
}
|
||||
|
||||
namespace {
|
||||
std::string GetPyObjId(const py::object &obj) {
|
||||
py::object out = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
|
||||
if (py::isinstance<py::none>(out)) {
|
||||
MS_LOG(EXCEPTION) << "Get pyobj failed";
|
||||
}
|
||||
return out.cast<std::string>();
|
||||
}
|
||||
// If any mixed precision flag add a cast node after the parameter node.
|
||||
// argument obj should be python Parameter object
|
||||
// it will be converted to Parameter node here
|
||||
|
@ -94,23 +101,40 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
|
|||
if (py::isinstance<py::none>(name_attr)) {
|
||||
MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
|
||||
}
|
||||
|
||||
auto obj_id = GetPyObjId(obj);
|
||||
static std::vector<std::string> param_obj_ids;
|
||||
auto param_name = py::cast<std::string>(name_attr);
|
||||
auto top_func_graph = Parser::GetTopFuncGraph();
|
||||
// If the parameter node has been created , return it.
|
||||
AnfNodePtr para_node = nullptr;
|
||||
for (auto const ¶m : top_func_graph->parameters()) {
|
||||
auto param_node = dyn_cast<Parameter>(param);
|
||||
if (param_node != nullptr && param_node->name() == param_name && !param_node->is_top_graph_param()) {
|
||||
para_node = param;
|
||||
MS_LOG(DEBUG) << "Found existing parameter for " << func_graph->ToString()
|
||||
<< ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
|
||||
break;
|
||||
if (param_node != nullptr && param_node->name() == param_name) {
|
||||
if (param_node->is_top_graph_param()) {
|
||||
// If the name of the input of construct is same as the parameters,
|
||||
// add suffix to the name of the input of construct.
|
||||
string suffix_name = param_node->name() + "_$";
|
||||
param_node->set_name(suffix_name);
|
||||
param_node->debug_info()->set_name(suffix_name);
|
||||
MS_LOG(DEBUG) << "Add suffix to the name of the input of construct " << func_graph->ToString()
|
||||
<< ", input: " << param_node->DebugString();
|
||||
} else {
|
||||
// Exist two parameter object which name is the same.
|
||||
if (std::find(param_obj_ids.begin(), param_obj_ids.end(), obj_id) == param_obj_ids.end()) {
|
||||
MS_LOG(EXCEPTION) << "The parameter " << param_node->DebugString() << " , its name '" << param_name
|
||||
<< "' already exists. Please set a unique name for the parameter.";
|
||||
}
|
||||
para_node = param;
|
||||
MS_LOG(DEBUG) << "Found existing parameter for " << func_graph->ToString()
|
||||
<< ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (para_node == nullptr) {
|
||||
auto node = top_func_graph->AddWeightParameter(param_name);
|
||||
auto value = py::cast<tensor::MetaTensorPtr>(obj);
|
||||
param_obj_ids.emplace_back(obj_id);
|
||||
node->set_default_param(value);
|
||||
// Set abstract for parameter
|
||||
auto abs = value->ToAbstract();
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -658,24 +658,17 @@ class ParameterTuple(tuple):
|
|||
"""Create instance object of ParameterTuple."""
|
||||
data = tuple(iterable)
|
||||
ids = set()
|
||||
orders = {}
|
||||
names = set()
|
||||
for x in data:
|
||||
if not isinstance(x, Parameter):
|
||||
raise TypeError(f"ParameterTuple input should be `Parameter` collection."
|
||||
f"But got a {type(iterable)}, {iterable}")
|
||||
if id(x) not in ids:
|
||||
if x.name in names:
|
||||
raise ValueError("The value {} , its name '{}' already exists. "
|
||||
"Please set a unique name for the parameter.".format(x, x.name))
|
||||
names.add(x.name)
|
||||
ids.add(id(x))
|
||||
if x.name not in orders.keys():
|
||||
orders[x.name] = [0, x]
|
||||
else:
|
||||
if isinstance(orders[x.name], list):
|
||||
name = x.name
|
||||
orders[name][1].name = name + "_" + str(0)
|
||||
x.name = x.name + "_" + str(1)
|
||||
orders[name] = 1
|
||||
else:
|
||||
orders[x.name] += 1
|
||||
x.name = x.name + "_" + str(orders[x.name])
|
||||
return tuple.__new__(ParameterTuple, tuple(data))
|
||||
|
||||
def clone(self, prefix, init='same'):
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -565,8 +565,8 @@ class Cell(Cell_):
|
|||
self._id += 1
|
||||
self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
|
||||
if item.name in exist_names:
|
||||
raise ValueError("The value {} , its name '{}' already exists.".
|
||||
format(value, item.name))
|
||||
raise ValueError("The value {} , its name '{}' already exists. "
|
||||
"Please set a unique name for the parameter.".format(value, item.name))
|
||||
exist_names.add(item.name)
|
||||
|
||||
if context._get_mode() == context.PYNATIVE_MODE:
|
||||
|
@ -589,8 +589,8 @@ class Cell(Cell_):
|
|||
item.name = item.name + "$" + str(self._id)
|
||||
self._id += 1
|
||||
if item.name in self.exist_names:
|
||||
raise ValueError("The value {} , its name '{}' already exists.".
|
||||
format(value, item.name))
|
||||
raise ValueError("The value {} , its name '{}' already exists. "
|
||||
"Please set a unique name for the parameter.".format(value, item.name))
|
||||
self.exist_names.add(item.name)
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
|
@ -1114,8 +1114,8 @@ class Cell(Cell_):
|
|||
names = set("")
|
||||
for value, param in self.parameters_and_names():
|
||||
if param.name in names:
|
||||
raise ValueError("The value of {} is {}, its name '{}' already exists.".
|
||||
format(value, param, param.name))
|
||||
raise ValueError("The value of {} is {}, its name '{}' already exists. "
|
||||
"Please set a unique name for the parameter.". format(value, param, param.name))
|
||||
names.add(param.name)
|
||||
|
||||
def parameters_and_names(self, name_prefix='', expand=True):
|
||||
|
|
|
@ -732,70 +732,3 @@ def test_assign_in_zip_loop():
|
|||
net = AssignInZipLoop()
|
||||
out = net(x)
|
||||
assert np.all(out.asnumpy() == 1)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: If parameter in list or tuple is not given a name, will give it a unique name.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")
|
||||
self.param_c = Parameter(Tensor([3], ms.float32))
|
||||
self.param_d = Parameter(Tensor([4], ms.float32))
|
||||
self.param_tuple = (Parameter(Tensor([5], ms.float32)),
|
||||
Parameter(Tensor([6], ms.float32)))
|
||||
self.param_list = [Parameter(Tensor([5], ms.float32)),
|
||||
Parameter(Tensor([6], ms.float32))]
|
||||
|
||||
def construct(self, x):
|
||||
res1 = self.param_a + self.param_b + self.param_c + self.param_d
|
||||
res1 = res1 - self.param_list[0] + self.param_list[1] + x
|
||||
res2 = self.param_list[0] + self.param_list[1]
|
||||
return res1, res2
|
||||
|
||||
net = ParamNet()
|
||||
x = Tensor([10], ms.float32)
|
||||
output1, output2 = net(x)
|
||||
output1_expect = Tensor(21, ms.float32)
|
||||
output2_expect = Tensor(11, ms.float32)
|
||||
assert output1 == output1_expect
|
||||
assert output2 == output2_expect
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_same_name_between_tuple_or_list():
|
||||
"""
|
||||
Feature: Check the names of parameters between tuple or list.
|
||||
Description: If the same name exists between tuple and list, an exception will be thrown.
|
||||
Expectation: Get the expected exception report.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_tuple = (Parameter(Tensor([1], ms.float32), name="name_a"),
|
||||
Parameter(Tensor([2], ms.float32)))
|
||||
self.param_list = [Parameter(Tensor([3], ms.float32), name="name_a"),
|
||||
Parameter(Tensor([4], ms.float32))]
|
||||
|
||||
def construct(self, x):
|
||||
res = self.param_tuple[0] + self.param_tuple[1] + self.param_list[0] + self.param_listp[1] + x
|
||||
return res
|
||||
|
||||
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
|
||||
net = ParamNet()
|
||||
x = Tensor([10], ms.float32)
|
||||
output = net(x)
|
||||
output_expect = Tensor(20, ms.float32)
|
||||
assert output == output_expect
|
||||
|
|
|
@ -0,0 +1,387 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import ParameterTuple
|
||||
from mindspore import Tensor, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_1_1():
|
||||
"""
|
||||
Feature: Check the names of parameters and the names of inputs of construct.
|
||||
Description: If the name of the input of construct is same as the parameters, add suffix to the name of the input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")
|
||||
|
||||
def construct(self, name_a):
|
||||
return self.param_a + self.param_b - name_a
|
||||
|
||||
net = ParamNet()
|
||||
res = net(Tensor([3], ms.float32))
|
||||
assert res == 0
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_1_2():
|
||||
"""
|
||||
Feature: Check the names of parameters and the names of inputs of construct.
|
||||
Description: If the name of the input of construct is same as the parameters, add suffix to the name of the input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.param_b = ParameterTuple((Parameter(Tensor([2], ms.float32), name="name_b"), self.param_a))
|
||||
|
||||
def construct(self, name_b):
|
||||
return self.param_a + self.param_b[0] - name_b
|
||||
|
||||
net = ParamNet()
|
||||
res = net(Tensor([3], ms.float32))
|
||||
assert res == 0
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_2_1():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: If parameters in init have same name, an exception will be thrown.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.param_b = Parameter(Tensor([2], ms.float32), name="name_a")
|
||||
|
||||
def construct(self):
|
||||
return self.param_a + self.param_b
|
||||
|
||||
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
|
||||
net = ParamNet()
|
||||
res = net()
|
||||
assert res == 3
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_2_2():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: Check the name of parameter in init.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.res1 = ParameterTuple((Parameter(Tensor([2], ms.float32)), self.param_a))
|
||||
self.param_a = Parameter(Tensor([3], ms.float32), name="name_a")
|
||||
self.res2 = self.res1[0] + self.param_a
|
||||
|
||||
def construct(self):
|
||||
return self.param_a + self.res1 + self.res2
|
||||
|
||||
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
|
||||
net = ParamNet()
|
||||
res = net()
|
||||
assert res == 10
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_3():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: Check the name of parameter in init.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_a = Parameter(Tensor([1], ms.float32))
|
||||
self.param_b = Parameter(Tensor([2], ms.float32))
|
||||
|
||||
def construct(self):
|
||||
return self.param_a + self.param_b
|
||||
|
||||
net = ParamNet()
|
||||
res = net()
|
||||
assert res == 3
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_4():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: Check the name of parameter in init.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.res1 = ParameterTuple((Parameter(Tensor([2], ms.float32), name="name_a"),
|
||||
Parameter(Tensor([4], ms.float32), name="name_a")))
|
||||
|
||||
def construct(self):
|
||||
return self.res1[0] + self.res1[1]
|
||||
|
||||
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
|
||||
net = ParamNet()
|
||||
res = net()
|
||||
assert res == 6
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_5_1():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: Check the name of parameter in init.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.res1 = ParameterTuple((Parameter(Tensor([2], ms.float32)), Parameter(Tensor([4], ms.float32))))
|
||||
|
||||
def construct(self):
|
||||
return self.res1[0] + self.res1[1]
|
||||
|
||||
with pytest.raises(ValueError, match="its name 'Parameter' already exists."):
|
||||
net = ParamNet()
|
||||
res = net()
|
||||
assert res == 6
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_5_2():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: Check the name of parameter in init.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.res1 = ParameterTuple((Parameter(Tensor([2], ms.float32)), self.param_a))
|
||||
self.param_a = Parameter(Tensor([3], ms.float32), name="name_b")
|
||||
self.res2 = self.res1[0] + self.param_a
|
||||
|
||||
def construct(self):
|
||||
return self.param_a + self.res1[0] + self.res2
|
||||
|
||||
net = ParamNet()
|
||||
res = net()
|
||||
assert res == 10
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_list_tuple_no_name():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: Check the name of parameter in init.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_tuple = (Parameter(Tensor([5], ms.float32)), Parameter(Tensor([6], ms.float32)))
|
||||
self.param_list = [Parameter(Tensor([7], ms.float32)), Parameter(Tensor([8], ms.float32))]
|
||||
|
||||
def construct(self):
|
||||
return self.param_tuple[0] + self.param_tuple[1] + self.param_list[0] + self.param_list[1]
|
||||
|
||||
net = ParamNet()
|
||||
res = net()
|
||||
assert res == 26
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_in_tuple():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: Check the name of parameter in init.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")
|
||||
self.param_tuple = ParameterTuple((self.param_a, self.param_b))
|
||||
|
||||
def construct(self):
|
||||
return self.param_a + self.param_b + self.param_tuple[0] + self.param_tuple[1]
|
||||
|
||||
net = ParamNet()
|
||||
res = net()
|
||||
assert res == 6
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_parameter_tuple_1():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: Check the name of parameter in init.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.param_tuple = ParameterTuple((Parameter(Tensor([5], ms.float32), name="name_a"),
|
||||
Parameter(Tensor([5], ms.float32), name="name_b")))
|
||||
|
||||
def construct(self):
|
||||
return self.param_a + self.param_tuple[0] + self.param_tuple[1]
|
||||
|
||||
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
|
||||
net = ParamNet()
|
||||
res = net()
|
||||
assert res == 7
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_parameter_tuple_2():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: Check the name of parameter in init.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.param_tuple = ParameterTuple((self.param_a, self.param_a, self.param_a))
|
||||
|
||||
def construct(self):
|
||||
return self.param_a + self.param_tuple[0] + self.param_tuple[1] + self.param_tuple[2]
|
||||
|
||||
net = ParamNet()
|
||||
res = net()
|
||||
assert res == 4
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: If parameter in list or tuple is not given a name, will give it a unique name.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")
|
||||
self.param_c = Parameter(Tensor([3], ms.float32))
|
||||
self.param_d = Parameter(Tensor([4], ms.float32))
|
||||
self.param_tuple = (Parameter(Tensor([5], ms.float32)),
|
||||
Parameter(Tensor([6], ms.float32)))
|
||||
self.param_list = [Parameter(Tensor([5], ms.float32)),
|
||||
Parameter(Tensor([6], ms.float32))]
|
||||
|
||||
def construct(self, x):
|
||||
res1 = self.param_a + self.param_b + self.param_c + self.param_d
|
||||
res1 = res1 - self.param_list[0] + self.param_list[1] + x
|
||||
res2 = self.param_list[0] + self.param_list[1]
|
||||
return res1, res2
|
||||
|
||||
net = ParamNet()
|
||||
x = Tensor([10], ms.float32)
|
||||
output1, output2 = net(x)
|
||||
output1_expect = Tensor(21, ms.float32)
|
||||
output2_expect = Tensor(11, ms.float32)
|
||||
assert output1 == output1_expect
|
||||
assert output2 == output2_expect
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_same_name_between_tuple_or_list():
|
||||
"""
|
||||
Feature: Check the names of parameters between tuple or list.
|
||||
Description: If the same name exists between tuple and list, an exception will be thrown.
|
||||
Expectation: Get the expected exception report.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_tuple = (Parameter(Tensor([1], ms.float32), name="name_a"),
|
||||
Parameter(Tensor([2], ms.float32)))
|
||||
self.param_list = [Parameter(Tensor([3], ms.float32), name="name_a"),
|
||||
Parameter(Tensor([4], ms.float32))]
|
||||
|
||||
def construct(self, x):
|
||||
res = self.param_tuple[0] + self.param_tuple[1] + self.param_list[0] + self.param_listp[1] + x
|
||||
return res
|
||||
|
||||
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
|
||||
net = ParamNet()
|
||||
x = Tensor([10], ms.float32)
|
||||
output = net(x)
|
||||
output_expect = Tensor(20, ms.float32)
|
||||
assert output == output_expect
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.common.parameter import ParameterTuple, Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class FullyConnectedNet(nn.Cell):
|
||||
def __init__(self, input_size, hidden_size, output_size):
|
||||
super(FullyConnectedNet, self).__init__(auto_prefix=False)
|
||||
self.linear1 = nn.Dense(input_size, hidden_size, weight_init="XavierUniform")
|
||||
self.linear2 = nn.Dense(hidden_size, output_size, weight_init="XavierUniform")
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.relu(self.linear1(x))
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
class EmaUpdate(nn.Cell):
|
||||
def __init__(self, policy_net, target_net, tau, period):
|
||||
super(EmaUpdate, self).__init__()
|
||||
self.tau = tau
|
||||
self.period = period
|
||||
# Use CellList manage parameters of multiple cells
|
||||
self.cell_list = nn.CellList()
|
||||
self.cell_list.append(policy_net)
|
||||
self.cell_list.append(target_net)
|
||||
self.policy_param = ParameterTuple(self.cell_list[0].get_parameters())
|
||||
self.target_param = ParameterTuple(self.cell_list[1].get_parameters())
|
||||
self.step = Parameter(initializer(0, [1]), name='step', requires_grad=False)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.assignadd = P.AssignAdd()
|
||||
|
||||
def ema(self, tau, policy_param, target_param):
|
||||
new_param = (1 - tau) * target_param + tau * policy_param
|
||||
out = P.Assign()(target_param, new_param)
|
||||
return out
|
||||
|
||||
def construct(self):
|
||||
if self.step % self.period == 0:
|
||||
self.hyper_map(F.partial(self.ema, self.tau), self.policy_param, self.target_param)
|
||||
return self.assignadd(self.step, 1)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_target_update():
|
||||
"""
|
||||
Feature: manage parameters with CellList.
|
||||
Description: Check the name of parameter in CellList.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
policy_net = FullyConnectedNet(4, 100, 2)
|
||||
target_net = FullyConnectedNet(4, 100, 2)
|
||||
tau = 0.2
|
||||
tau_tensor = Tensor(np.array([tau], dtype=np.float32))
|
||||
ema_update = EmaUpdate(policy_net, target_net, tau_tensor, period=1)
|
||||
ema_update()
|
|
@ -0,0 +1,129 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, ms_function
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import ParameterTuple
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_ms_function_1():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: Check the name of parameter in ms_function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
param_b = Parameter(Tensor([2], ms.float32), name="name_a")
|
||||
|
||||
@ms_function
|
||||
def test_parameter_ms_function():
|
||||
return param_a + param_b
|
||||
|
||||
with pytest.raises(RuntimeError, match="its name 'name_a' already exists."):
|
||||
res = test_parameter_ms_function()
|
||||
assert res == 3
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_ms_function_2():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: Check the name of parameter in ms_function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
param_b = param_a
|
||||
|
||||
@ms_function
|
||||
def test_parameter_ms_function():
|
||||
return param_a + param_b
|
||||
|
||||
res = test_parameter_ms_function()
|
||||
assert res == 2
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_ms_function_3():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: Check the name of parameter in ms_function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
param_a = Parameter(Tensor([1], ms.float32))
|
||||
param_b = Parameter(Tensor([2], ms.float32))
|
||||
|
||||
@ms_function
|
||||
def test_parameter_ms_function():
|
||||
return param_a + param_b
|
||||
|
||||
with pytest.raises(RuntimeError, match="its name 'Parameter' already exists."):
|
||||
res = test_parameter_ms_function()
|
||||
assert res == 3
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_ms_function_4():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: Check the name of parameter in ms_function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
|
||||
param_a = ParameterTuple((Parameter(Tensor([1], ms.float32), name="name_a"),
|
||||
Parameter(Tensor([2], ms.float32), name="name_a")))
|
||||
|
||||
@ms_function
|
||||
def test_parameter_ms_function():
|
||||
return param_a[0] + param_a[1]
|
||||
|
||||
res = test_parameter_ms_function()
|
||||
assert res == 3
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_ms_function_5():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: Check the name of parameter in ms_function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="its name 'Parameter' already exists."):
|
||||
param_a = ParameterTuple((Parameter(Tensor([1], ms.float32)), Parameter(Tensor([2], ms.float32))))
|
||||
|
||||
@ms_function
|
||||
def test_parameter_ms_function():
|
||||
return param_a[0] + param_a[1]
|
||||
|
||||
res = test_parameter_ms_function()
|
||||
assert res == 3
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -26,7 +26,7 @@ class IterableObjc:
|
|||
cont = 0
|
||||
while cont < 3:
|
||||
cont += 1
|
||||
yield Parameter(Tensor(cont), name="cont")
|
||||
yield Parameter(Tensor(cont), name="cont" + str(cont))
|
||||
|
||||
|
||||
params = IterableObjc()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -29,7 +29,7 @@ class Net(Cell):
|
|||
self.matmul = P.MatMul().shard(strategy1)
|
||||
self.bias_add = P.BiasAdd().shard(strategy2)
|
||||
self.add_weight = Parameter(add_weight, "w1")
|
||||
self.mul_weight = Parameter(matmul_weight, "w1")
|
||||
self.mul_weight = Parameter(matmul_weight, "w2")
|
||||
self.bias = Parameter(bias, "bias")
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
|
|
Loading…
Reference in New Issue