restricting modify non_Parameter class members

This commit is contained in:
buxue 2020-06-02 20:09:35 +08:00
parent 5d397d8404
commit 94c9019d8e
3 changed files with 64 additions and 9 deletions

View File

@ -1175,11 +1175,11 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob
auto filename = location[0].cast<std::string>();
auto line_no = location[1].cast<int>();
// Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
if (!py::hasattr(ast()->obj(), attr_name.c_str())) {
if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined, at " << filename << ":"
<< line_no;
}
auto obj = ast()->obj().attr(attr_name.c_str());
auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
auto obj_type = obj.attr("__class__").attr("__name__");
if (!py::hasattr(obj, "__parameter__")) {
MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
@ -1205,8 +1205,18 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje
// getitem apply should return the sequence data structure itself
std::string var_name = "";
if (ast_->IsClassMember(value_obj)) {
var_name = "self.";
(void)var_name.append(value_obj.attr("attr").cast<std::string>());
std::string attr_name = value_obj.attr("attr").cast<std::string>();
var_name = "self." + attr_name;
if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
MS_EXCEPTION(TypeError) << "'" << var_name << "' was not defined in the class '__init__' function.";
}
auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
auto obj_type = obj.attr("__class__").attr("__name__");
if (!py::hasattr(obj, "__parameter__")) {
MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
<< py::str(obj).cast<std::string>() << "' with type '"
<< py::str(obj_type).cast<std::string>() << "'.";
}
} else {
var_name = value_obj.attr("id").cast<std::string>();
}
@ -1231,7 +1241,7 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
}
}
// process a assign statement , such as a =b, a,b = tup
// process a assign statement, such as a =b, a,b = tup
FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast assgin";
py::object value_object = python_adapter::GetPyObjAttr(node, "value");

View File

@ -17,6 +17,7 @@
@Desc : test_dictionary
"""
import numpy as np
import pytest
from mindspore import Tensor, context
from mindspore.nn import Cell
@ -89,7 +90,9 @@ def test_dict_set_or_get_item():
return ret
net = DictNet()
assert net() == (88, 99, 4, 5, 6)
with pytest.raises(TypeError) as ex:
net()
assert "'self.dict_' should be a Parameter" in str(ex.value)
def test_dict_set_or_get_item_2():
@ -135,7 +138,9 @@ def test_dict_set_or_get_item_3():
return self.dict_["x"]
net = DictNet()
assert net() == Tensor(np.ones([4, 2, 3], np.float32))
with pytest.raises(TypeError) as ex:
net()
assert "'self.dict_' should be a Parameter" in str(ex.value)
def test_dict_set_item():

View File

@ -15,6 +15,7 @@
import functools
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
@ -24,6 +25,7 @@ from tests.mindspore_test_framework.mindspore_test import mindspore_test
from tests.mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
context.set_context(mode=context.GRAPH_MODE)
def test_list_equal():
class Net(nn.Cell):
@ -109,7 +111,7 @@ def test_list_append():
assert net(x, y) == y
def test_list_append_2():
def test_class_member_list_append():
class Net(nn.Cell):
def __init__(self, z: list):
super(Net, self).__init__()
@ -129,7 +131,45 @@ def test_list_append_2():
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = [[1, 2], 3]
net = Net(z)
assert net(x, y) == x
with pytest.raises(TypeError) as ex:
net(x, y)
assert "'self.z' should be a Parameter, but got '[[1, 2], 3]' with type 'list'." in str(ex.value)
def test_class_member_not_defined():
class Net(nn.Cell):
def __init__(self, z: list):
super(Net, self).__init__()
self.z = z
def construct(self, x, y):
self.x[0] = 9
return self.x
z = [[1, 2], 3]
net = Net(z)
with pytest.raises(TypeError) as ex:
net()
assert "'self.x' was not defined in the class '__init__' function." in str(ex.value)
def test_change_list_element():
class Net(nn.Cell):
def __init__(self, z: list):
super(Net, self).__init__()
self.z = z
def construct(self, x, y):
self.z[0] = x
return self.z[0]
x = Tensor(np.ones([6, 8, 10], np.int32))
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = [[1, 2], 3]
net = Net(z)
with pytest.raises(TypeError) as ex:
net(x, y)
assert "'self.z' should be a Parameter, but got '[[1, 2], 3]' with type 'list'." in str(ex.value)
class ListOperate(nn.Cell):