[ME] Add parameter name check.

This commit is contained in:
Margaret_wangrui 2021-11-22 20:22:27 +08:00
parent 33bc6978a0
commit 462c38813d
10 changed files with 646 additions and 98 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -74,6 +74,13 @@ abstract::AbstractBasePtr ClassType::ToAbstract() {
}
namespace {
std::string GetPyObjId(const py::object &obj) {
py::object out = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
if (py::isinstance<py::none>(out)) {
MS_LOG(EXCEPTION) << "Get pyobj failed";
}
return out.cast<std::string>();
}
// If any mixed precision flag add a cast node after the parameter node.
// argument obj should be python Parameter object
// it will be converted to Parameter node here
@ -94,23 +101,40 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
if (py::isinstance<py::none>(name_attr)) {
MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
}
auto obj_id = GetPyObjId(obj);
static std::vector<std::string> param_obj_ids;
auto param_name = py::cast<std::string>(name_attr);
auto top_func_graph = Parser::GetTopFuncGraph();
// If the parameter node has been created , return it.
AnfNodePtr para_node = nullptr;
for (auto const &param : top_func_graph->parameters()) {
auto param_node = dyn_cast<Parameter>(param);
if (param_node != nullptr && param_node->name() == param_name && !param_node->is_top_graph_param()) {
para_node = param;
MS_LOG(DEBUG) << "Found existing parameter for " << func_graph->ToString()
<< ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
break;
if (param_node != nullptr && param_node->name() == param_name) {
if (param_node->is_top_graph_param()) {
// If the name of the input of construct is same as the parameters,
// add suffix to the name of the input of construct.
string suffix_name = param_node->name() + "_$";
param_node->set_name(suffix_name);
param_node->debug_info()->set_name(suffix_name);
MS_LOG(DEBUG) << "Add suffix to the name of the input of construct " << func_graph->ToString()
<< ", input: " << param_node->DebugString();
} else {
// Exist two parameter object which name is the same.
if (std::find(param_obj_ids.begin(), param_obj_ids.end(), obj_id) == param_obj_ids.end()) {
MS_LOG(EXCEPTION) << "The parameter " << param_node->DebugString() << " , its name '" << param_name
<< "' already exists. Please set a unique name for the parameter.";
}
para_node = param;
MS_LOG(DEBUG) << "Found existing parameter for " << func_graph->ToString()
<< ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
break;
}
}
}
if (para_node == nullptr) {
auto node = top_func_graph->AddWeightParameter(param_name);
auto value = py::cast<tensor::MetaTensorPtr>(obj);
param_obj_ids.emplace_back(obj_id);
node->set_default_param(value);
// Set abstract for parameter
auto abs = value->ToAbstract();

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -658,24 +658,17 @@ class ParameterTuple(tuple):
"""Create instance object of ParameterTuple."""
data = tuple(iterable)
ids = set()
orders = {}
names = set()
for x in data:
if not isinstance(x, Parameter):
raise TypeError(f"ParameterTuple input should be `Parameter` collection."
f"But got a {type(iterable)}, {iterable}")
if id(x) not in ids:
if x.name in names:
raise ValueError("The value {} , its name '{}' already exists. "
"Please set a unique name for the parameter.".format(x, x.name))
names.add(x.name)
ids.add(id(x))
if x.name not in orders.keys():
orders[x.name] = [0, x]
else:
if isinstance(orders[x.name], list):
name = x.name
orders[name][1].name = name + "_" + str(0)
x.name = x.name + "_" + str(1)
orders[name] = 1
else:
orders[x.name] += 1
x.name = x.name + "_" + str(orders[x.name])
return tuple.__new__(ParameterTuple, tuple(data))
def clone(self, prefix, init='same'):

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -565,8 +565,8 @@ class Cell(Cell_):
self._id += 1
self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
if item.name in exist_names:
raise ValueError("The value {} , its name '{}' already exists.".
format(value, item.name))
raise ValueError("The value {} , its name '{}' already exists. "
"Please set a unique name for the parameter.".format(value, item.name))
exist_names.add(item.name)
if context._get_mode() == context.PYNATIVE_MODE:
@ -589,8 +589,8 @@ class Cell(Cell_):
item.name = item.name + "$" + str(self._id)
self._id += 1
if item.name in self.exist_names:
raise ValueError("The value {} , its name '{}' already exists.".
format(value, item.name))
raise ValueError("The value {} , its name '{}' already exists. "
"Please set a unique name for the parameter.".format(value, item.name))
self.exist_names.add(item.name)
object.__setattr__(self, name, value)
@ -1114,8 +1114,8 @@ class Cell(Cell_):
names = set("")
for value, param in self.parameters_and_names():
if param.name in names:
raise ValueError("The value of {} is {}, its name '{}' already exists.".
format(value, param, param.name))
raise ValueError("The value of {} is {}, its name '{}' already exists. "
"Please set a unique name for the parameter.". format(value, param, param.name))
names.add(param.name)
def parameters_and_names(self, name_prefix='', expand=True):

View File

@ -732,70 +732,3 @@ def test_assign_in_zip_loop():
net = AssignInZipLoop()
out = net(x)
assert np.all(out.asnumpy() == 1)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_parameter():
"""
Feature: Check the names of parameters.
Description: If parameter in list or tuple is not given a name, will give it a unique name.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")
self.param_c = Parameter(Tensor([3], ms.float32))
self.param_d = Parameter(Tensor([4], ms.float32))
self.param_tuple = (Parameter(Tensor([5], ms.float32)),
Parameter(Tensor([6], ms.float32)))
self.param_list = [Parameter(Tensor([5], ms.float32)),
Parameter(Tensor([6], ms.float32))]
def construct(self, x):
res1 = self.param_a + self.param_b + self.param_c + self.param_d
res1 = res1 - self.param_list[0] + self.param_list[1] + x
res2 = self.param_list[0] + self.param_list[1]
return res1, res2
net = ParamNet()
x = Tensor([10], ms.float32)
output1, output2 = net(x)
output1_expect = Tensor(21, ms.float32)
output2_expect = Tensor(11, ms.float32)
assert output1 == output1_expect
assert output2 == output2_expect
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_parameter_same_name_between_tuple_or_list():
"""
Feature: Check the names of parameters between tuple or list.
Description: If the same name exists between tuple and list, an exception will be thrown.
Expectation: Get the expected exception report.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_tuple = (Parameter(Tensor([1], ms.float32), name="name_a"),
Parameter(Tensor([2], ms.float32)))
self.param_list = [Parameter(Tensor([3], ms.float32), name="name_a"),
Parameter(Tensor([4], ms.float32))]
def construct(self, x):
res = self.param_tuple[0] + self.param_tuple[1] + self.param_list[0] + self.param_listp[1] + x
return res
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
net = ParamNet()
x = Tensor([10], ms.float32)
output = net(x)
output_expect = Tensor(20, ms.float32)
assert output == output_expect

View File

View File

@ -0,0 +1,387 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore as ms
from mindspore.nn import Cell
from mindspore.common.parameter import Parameter
from mindspore.common import ParameterTuple
from mindspore import Tensor, context
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_1_1():
"""
Feature: Check the names of parameters and the names of inputs of construct.
Description: If the name of the input of construct is same as the parameters, add suffix to the name of the input.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")
def construct(self, name_a):
return self.param_a + self.param_b - name_a
net = ParamNet()
res = net(Tensor([3], ms.float32))
assert res == 0
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_1_2():
"""
Feature: Check the names of parameters and the names of inputs of construct.
Description: If the name of the input of construct is same as the parameters, add suffix to the name of the input.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
self.param_b = ParameterTuple((Parameter(Tensor([2], ms.float32), name="name_b"), self.param_a))
def construct(self, name_b):
return self.param_a + self.param_b[0] - name_b
net = ParamNet()
res = net(Tensor([3], ms.float32))
assert res == 0
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_2_1():
"""
Feature: Check the names of parameters.
Description: If parameters in init have same name, an exception will be thrown.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
self.param_b = Parameter(Tensor([2], ms.float32), name="name_a")
def construct(self):
return self.param_a + self.param_b
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
net = ParamNet()
res = net()
assert res == 3
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_2_2():
"""
Feature: Check the names of parameters.
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
self.res1 = ParameterTuple((Parameter(Tensor([2], ms.float32)), self.param_a))
self.param_a = Parameter(Tensor([3], ms.float32), name="name_a")
self.res2 = self.res1[0] + self.param_a
def construct(self):
return self.param_a + self.res1 + self.res2
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
net = ParamNet()
res = net()
assert res == 10
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_3():
"""
Feature: Check the names of parameters.
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_a = Parameter(Tensor([1], ms.float32))
self.param_b = Parameter(Tensor([2], ms.float32))
def construct(self):
return self.param_a + self.param_b
net = ParamNet()
res = net()
assert res == 3
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_4():
"""
Feature: Check the names of parameters.
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.res1 = ParameterTuple((Parameter(Tensor([2], ms.float32), name="name_a"),
Parameter(Tensor([4], ms.float32), name="name_a")))
def construct(self):
return self.res1[0] + self.res1[1]
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
net = ParamNet()
res = net()
assert res == 6
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_5_1():
"""
Feature: Check the names of parameters.
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.res1 = ParameterTuple((Parameter(Tensor([2], ms.float32)), Parameter(Tensor([4], ms.float32))))
def construct(self):
return self.res1[0] + self.res1[1]
with pytest.raises(ValueError, match="its name 'Parameter' already exists."):
net = ParamNet()
res = net()
assert res == 6
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_5_2():
"""
Feature: Check the names of parameters.
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
self.res1 = ParameterTuple((Parameter(Tensor([2], ms.float32)), self.param_a))
self.param_a = Parameter(Tensor([3], ms.float32), name="name_b")
self.res2 = self.res1[0] + self.param_a
def construct(self):
return self.param_a + self.res1[0] + self.res2
net = ParamNet()
res = net()
assert res == 10
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_list_tuple_no_name():
"""
Feature: Check the names of parameters.
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_tuple = (Parameter(Tensor([5], ms.float32)), Parameter(Tensor([6], ms.float32)))
self.param_list = [Parameter(Tensor([7], ms.float32)), Parameter(Tensor([8], ms.float32))]
def construct(self):
return self.param_tuple[0] + self.param_tuple[1] + self.param_list[0] + self.param_list[1]
net = ParamNet()
res = net()
assert res == 26
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_in_tuple():
"""
Feature: Check the names of parameters.
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")
self.param_tuple = ParameterTuple((self.param_a, self.param_b))
def construct(self):
return self.param_a + self.param_b + self.param_tuple[0] + self.param_tuple[1]
net = ParamNet()
res = net()
assert res == 6
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_parameter_tuple_1():
"""
Feature: Check the names of parameters.
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
self.param_tuple = ParameterTuple((Parameter(Tensor([5], ms.float32), name="name_a"),
Parameter(Tensor([5], ms.float32), name="name_b")))
def construct(self):
return self.param_a + self.param_tuple[0] + self.param_tuple[1]
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
net = ParamNet()
res = net()
assert res == 7
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_parameter_tuple_2():
"""
Feature: Check the names of parameters.
Description: Check the name of parameter in init.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
self.param_tuple = ParameterTuple((self.param_a, self.param_a, self.param_a))
def construct(self):
return self.param_a + self.param_tuple[0] + self.param_tuple[1] + self.param_tuple[2]
net = ParamNet()
res = net()
assert res == 4
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_parameter():
"""
Feature: Check the names of parameters.
Description: If parameter in list or tuple is not given a name, will give it a unique name.
Expectation: No exception.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")
self.param_c = Parameter(Tensor([3], ms.float32))
self.param_d = Parameter(Tensor([4], ms.float32))
self.param_tuple = (Parameter(Tensor([5], ms.float32)),
Parameter(Tensor([6], ms.float32)))
self.param_list = [Parameter(Tensor([5], ms.float32)),
Parameter(Tensor([6], ms.float32))]
def construct(self, x):
res1 = self.param_a + self.param_b + self.param_c + self.param_d
res1 = res1 - self.param_list[0] + self.param_list[1] + x
res2 = self.param_list[0] + self.param_list[1]
return res1, res2
net = ParamNet()
x = Tensor([10], ms.float32)
output1, output2 = net(x)
output1_expect = Tensor(21, ms.float32)
output2_expect = Tensor(11, ms.float32)
assert output1 == output1_expect
assert output2 == output2_expect
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_parameter_same_name_between_tuple_or_list():
"""
Feature: Check the names of parameters between tuple or list.
Description: If the same name exists between tuple and list, an exception will be thrown.
Expectation: Get the expected exception report.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_tuple = (Parameter(Tensor([1], ms.float32), name="name_a"),
Parameter(Tensor([2], ms.float32)))
self.param_list = [Parameter(Tensor([3], ms.float32), name="name_a"),
Parameter(Tensor([4], ms.float32))]
def construct(self, x):
res = self.param_tuple[0] + self.param_tuple[1] + self.param_list[0] + self.param_listp[1] + x
return res
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
net = ParamNet()
x = Tensor([10], ms.float32)
output = net(x)
output_expect = Tensor(20, ms.float32)
assert output == output_expect

View File

@ -0,0 +1,82 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore import Tensor, context
from mindspore.common.parameter import ParameterTuple, Parameter
from mindspore.common.initializer import initializer
context.set_context(mode=context.GRAPH_MODE)
class FullyConnectedNet(nn.Cell):
def __init__(self, input_size, hidden_size, output_size):
super(FullyConnectedNet, self).__init__(auto_prefix=False)
self.linear1 = nn.Dense(input_size, hidden_size, weight_init="XavierUniform")
self.linear2 = nn.Dense(hidden_size, output_size, weight_init="XavierUniform")
self.relu = nn.ReLU()
def construct(self, x):
x = self.relu(self.linear1(x))
x = self.linear2(x)
return x
class EmaUpdate(nn.Cell):
def __init__(self, policy_net, target_net, tau, period):
super(EmaUpdate, self).__init__()
self.tau = tau
self.period = period
# Use CellList manage parameters of multiple cells
self.cell_list = nn.CellList()
self.cell_list.append(policy_net)
self.cell_list.append(target_net)
self.policy_param = ParameterTuple(self.cell_list[0].get_parameters())
self.target_param = ParameterTuple(self.cell_list[1].get_parameters())
self.step = Parameter(initializer(0, [1]), name='step', requires_grad=False)
self.hyper_map = C.HyperMap()
self.assignadd = P.AssignAdd()
def ema(self, tau, policy_param, target_param):
new_param = (1 - tau) * target_param + tau * policy_param
out = P.Assign()(target_param, new_param)
return out
def construct(self):
if self.step % self.period == 0:
self.hyper_map(F.partial(self.ema, self.tau), self.policy_param, self.target_param)
return self.assignadd(self.step, 1)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_target_update():
"""
Feature: manage parameters with CellList.
Description: Check the name of parameter in CellList.
Expectation: No exception.
"""
policy_net = FullyConnectedNet(4, 100, 2)
target_net = FullyConnectedNet(4, 100, 2)
tau = 0.2
tau_tensor = Tensor(np.array([tau], dtype=np.float32))
ema_update = EmaUpdate(policy_net, target_net, tau_tensor, period=1)
ema_update()

View File

@ -0,0 +1,129 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore as ms
from mindspore import context, Tensor, ms_function
from mindspore.common.parameter import Parameter
from mindspore.common import ParameterTuple
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_ms_function_1():
"""
Feature: Check the names of parameters.
Description: Check the name of parameter in ms_function.
Expectation: No exception.
"""
param_a = Parameter(Tensor([1], ms.float32), name="name_a")
param_b = Parameter(Tensor([2], ms.float32), name="name_a")
@ms_function
def test_parameter_ms_function():
return param_a + param_b
with pytest.raises(RuntimeError, match="its name 'name_a' already exists."):
res = test_parameter_ms_function()
assert res == 3
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_ms_function_2():
"""
Feature: Check the names of parameters.
Description: Check the name of parameter in ms_function.
Expectation: No exception.
"""
param_a = Parameter(Tensor([1], ms.float32), name="name_a")
param_b = param_a
@ms_function
def test_parameter_ms_function():
return param_a + param_b
res = test_parameter_ms_function()
assert res == 2
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_ms_function_3():
"""
Feature: Check the names of parameters.
Description: Check the name of parameter in ms_function.
Expectation: No exception.
"""
param_a = Parameter(Tensor([1], ms.float32))
param_b = Parameter(Tensor([2], ms.float32))
@ms_function
def test_parameter_ms_function():
return param_a + param_b
with pytest.raises(RuntimeError, match="its name 'Parameter' already exists."):
res = test_parameter_ms_function()
assert res == 3
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_ms_function_4():
"""
Feature: Check the names of parameters.
Description: Check the name of parameter in ms_function.
Expectation: No exception.
"""
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
param_a = ParameterTuple((Parameter(Tensor([1], ms.float32), name="name_a"),
Parameter(Tensor([2], ms.float32), name="name_a")))
@ms_function
def test_parameter_ms_function():
return param_a[0] + param_a[1]
res = test_parameter_ms_function()
assert res == 3
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parameter_ms_function_5():
"""
Feature: Check the names of parameters.
Description: Check the name of parameter in ms_function.
Expectation: No exception.
"""
with pytest.raises(ValueError, match="its name 'Parameter' already exists."):
param_a = ParameterTuple((Parameter(Tensor([1], ms.float32)), Parameter(Tensor([2], ms.float32))))
@ms_function
def test_parameter_ms_function():
return param_a[0] + param_a[1]
res = test_parameter_ms_function()
assert res == 3

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -26,7 +26,7 @@ class IterableObjc:
cont = 0
while cont < 3:
cont += 1
yield Parameter(Tensor(cont), name="cont")
yield Parameter(Tensor(cont), name="cont" + str(cont))
params = IterableObjc()

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -29,7 +29,7 @@ class Net(Cell):
self.matmul = P.MatMul().shard(strategy1)
self.bias_add = P.BiasAdd().shard(strategy2)
self.add_weight = Parameter(add_weight, "w1")
self.mul_weight = Parameter(matmul_weight, "w1")
self.mul_weight = Parameter(matmul_weight, "w2")
self.bias = Parameter(bias, "bias")
self.reshape = P.Reshape()