restricting modify non_Parameter class members

This commit is contained in:
buxue 2020-06-02 20:09:35 +08:00
parent 5d397d8404
commit 94c9019d8e
3 changed files with 64 additions and 9 deletions

View File

@ -1175,11 +1175,11 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob
auto filename = location[0].cast<std::string>(); auto filename = location[0].cast<std::string>();
auto line_no = location[1].cast<int>(); auto line_no = location[1].cast<int>();
// Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
if (!py::hasattr(ast()->obj(), attr_name.c_str())) { if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined, at " << filename << ":" MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined, at " << filename << ":"
<< line_no; << line_no;
} }
auto obj = ast()->obj().attr(attr_name.c_str()); auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
auto obj_type = obj.attr("__class__").attr("__name__"); auto obj_type = obj.attr("__class__").attr("__name__");
if (!py::hasattr(obj, "__parameter__")) { if (!py::hasattr(obj, "__parameter__")) {
MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '" MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
@ -1205,8 +1205,18 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje
// getitem apply should return the sequence data structure itself // getitem apply should return the sequence data structure itself
std::string var_name = ""; std::string var_name = "";
if (ast_->IsClassMember(value_obj)) { if (ast_->IsClassMember(value_obj)) {
var_name = "self."; std::string attr_name = value_obj.attr("attr").cast<std::string>();
(void)var_name.append(value_obj.attr("attr").cast<std::string>()); var_name = "self." + attr_name;
if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
MS_EXCEPTION(TypeError) << "'" << var_name << "' was not defined in the class '__init__' function.";
}
auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
auto obj_type = obj.attr("__class__").attr("__name__");
if (!py::hasattr(obj, "__parameter__")) {
MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
<< py::str(obj).cast<std::string>() << "' with type '"
<< py::str(obj_type).cast<std::string>() << "'.";
}
} else { } else {
var_name = value_obj.attr("id").cast<std::string>(); var_name = value_obj.attr("id").cast<std::string>();
} }
@ -1231,7 +1241,7 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
} }
} }
// process a assign statement , such as a =b, a,b = tup // process a assign statement, such as a =b, a,b = tup
FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) { FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast assgin"; MS_LOG(DEBUG) << "Process ast assgin";
py::object value_object = python_adapter::GetPyObjAttr(node, "value"); py::object value_object = python_adapter::GetPyObjAttr(node, "value");

View File

@ -17,6 +17,7 @@
@Desc : test_dictionary @Desc : test_dictionary
""" """
import numpy as np import numpy as np
import pytest
from mindspore import Tensor, context from mindspore import Tensor, context
from mindspore.nn import Cell from mindspore.nn import Cell
@ -89,7 +90,9 @@ def test_dict_set_or_get_item():
return ret return ret
net = DictNet() net = DictNet()
assert net() == (88, 99, 4, 5, 6) with pytest.raises(TypeError) as ex:
net()
assert "'self.dict_' should be a Parameter" in str(ex.value)
def test_dict_set_or_get_item_2(): def test_dict_set_or_get_item_2():
@ -135,7 +138,9 @@ def test_dict_set_or_get_item_3():
return self.dict_["x"] return self.dict_["x"]
net = DictNet() net = DictNet()
assert net() == Tensor(np.ones([4, 2, 3], np.float32)) with pytest.raises(TypeError) as ex:
net()
assert "'self.dict_' should be a Parameter" in str(ex.value)
def test_dict_set_item(): def test_dict_set_item():

View File

@ -15,6 +15,7 @@
import functools import functools
import numpy as np import numpy as np
import pytest
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
@ -24,6 +25,7 @@ from tests.mindspore_test_framework.mindspore_test import mindspore_test
from tests.mindspore_test_framework.pipeline.forward.compile_forward \ from tests.mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
context.set_context(mode=context.GRAPH_MODE)
def test_list_equal(): def test_list_equal():
class Net(nn.Cell): class Net(nn.Cell):
@ -109,7 +111,7 @@ def test_list_append():
assert net(x, y) == y assert net(x, y) == y
def test_list_append_2(): def test_class_member_list_append():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, z: list): def __init__(self, z: list):
super(Net, self).__init__() super(Net, self).__init__()
@ -129,7 +131,45 @@ def test_list_append_2():
y = Tensor(np.zeros([3, 4, 5], np.int32)) y = Tensor(np.zeros([3, 4, 5], np.int32))
z = [[1, 2], 3] z = [[1, 2], 3]
net = Net(z) net = Net(z)
assert net(x, y) == x with pytest.raises(TypeError) as ex:
net(x, y)
assert "'self.z' should be a Parameter, but got '[[1, 2], 3]' with type 'list'." in str(ex.value)
def test_class_member_not_defined():
class Net(nn.Cell):
def __init__(self, z: list):
super(Net, self).__init__()
self.z = z
def construct(self, x, y):
self.x[0] = 9
return self.x
z = [[1, 2], 3]
net = Net(z)
with pytest.raises(TypeError) as ex:
net()
assert "'self.x' was not defined in the class '__init__' function." in str(ex.value)
def test_change_list_element():
class Net(nn.Cell):
def __init__(self, z: list):
super(Net, self).__init__()
self.z = z
def construct(self, x, y):
self.z[0] = x
return self.z[0]
x = Tensor(np.ones([6, 8, 10], np.int32))
y = Tensor(np.zeros([3, 4, 5], np.int32))
z = [[1, 2], 3]
net = Net(z)
with pytest.raises(TypeError) as ex:
net(x, y)
assert "'self.z' should be a Parameter, but got '[[1, 2], 3]' with type 'list'." in str(ex.value)
class ListOperate(nn.Cell): class ListOperate(nn.Cell):