!43651 报错不合理整改,增加测试用例

Merge pull request !43651 from ligan/cell
This commit is contained in:
i-robot 2022-10-17 10:01:06 +00:00 committed by Gitee
commit 1b82179635
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 72 additions and 3 deletions

View File

@ -116,6 +116,9 @@
"mindspore/tests/ut/python/mindrecord/test_mindrecord_exception.py" "redefined-outer-name"
"mindspore/tests/ut/python/mindrecord/test_mnist_to_mr.py" "redefined-outer-name"
"mindspore/tests/ut/python/nn/test_batchnorm.py" "no-value-for-parameter"
"mindspore/tests/ut/python/nn/test_cell_method_attribute.py" "no-method-argument"
"mindspore/tests/ut/python/nn/test_cell_method_attribute.py" "too-many-function-args"
"mindspore/tests/ut/python/nn/test_cell_method_attribute.py" "no-self-argument"
"mindspore/tests/ut/python/onnx/test_onnx.py" "unused-variable"
"mindspore/tests/ut/python/ops" "super-init-not-called"
"mindspore/tests/ut/python/ops/test_tensor_slice.py" "redefined-outer-name"

View File

@ -329,6 +329,9 @@ class Cell(Cell_):
if context.get_context is not None and context._get_mode() == context.PYNATIVE_MODE:
_pynative_executor.del_cell(self)
init_inputs_names = self.__init__.__code__.co_varnames
if "self" not in init_inputs_names:
raise TypeError("For 'Cell', the method '__init__' must have parameter 'self'. ")
# while deepcopy a cell instance, the copied cell instance can't be added to cells_compile_cache
# here using pop(id(self), None) to avoid KeyError exception
cells_compile_cache.pop(id(self), None)
@ -443,6 +446,10 @@ class Cell(Cell_):
f"and loss functions are configured with set_inputs.")
if len(inputs) > positional_args + default_args:
construct_inputs_names = self.construct.__code__.co_varnames
if 'self' not in construct_inputs_names:
raise TypeError(f"For 'Cell', the method 'construct' must have parameter 'self'. ")
raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument and "
f"{default_args} default argument, total {positional_args + default_args}, "
f"but got {len(inputs)}.")
@ -598,8 +605,8 @@ class Cell(Cell_):
def __call__(self, *args, **kwargs):
if self.__class__.construct is Cell.construct:
logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
f"it will call the super class(Cell) 'construct'.")
raise AttributeError("For 'Cell', the method 'construct' is not defined. ")
if kwargs:
bound_arguments = inspect.signature(self.construct).bind(*args, **kwargs)
bound_arguments.apply_defaults()

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore as ms
import mindspore.nn as nn
@ -49,3 +51,57 @@ def test_cell_call_cell_methods():
x = ms.Tensor(1)
y = ms.Tensor(2)
print(net(x, y))
def test_construct_require_self():
"""
Feature: Support use Cell method and attribute.
Description: Test function construct require self.
Expectation: No exception.
"""
x = ms.Tensor(1)
class ConstructRequireSelf(nn.Cell):
def construct(x):
return x
net = ConstructRequireSelf()
with pytest.raises(TypeError) as info:
net(x)
assert "construct" in str(info.value)
assert "self" in str(info.value)
def test_init_require_self():
"""
Feature: Support use Cell method and attribute.
Description: Test function __init__ require self.
Expectation: No exception.
"""
class InitRequireSelf(nn.Cell):
def __init__():
pass
with pytest.raises(TypeError):
InitRequireSelf()
def test_construct_exist():
"""
Feature: Support use Cell method and attribute.
Description: Test function construct not exist.
Expectation: No exception.
"""
class ConstructNotExist1(nn.Cell):
def cnosrtuct(self):
pass
class ConstructNotExist2(nn.Cell):
pass
net1 = ConstructNotExist1()
with pytest.raises(AttributeError):
net1()
net2 = ConstructNotExist2()
with pytest.raises(AttributeError):
net2()

View File

@ -175,4 +175,7 @@ def test_missing_construct():
np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
tensor = Tensor(np_input)
net = NetMissConstruct()
assert net(tensor) is None
with pytest.raises(AttributeError) as info:
net(tensor)
assert "construct" in str(info.value)
assert "not defined" in str(info.value)