support call super when class define in test_case.

This commit is contained in:
buxue 2020-07-23 17:26:08 +08:00
parent 684ff4f46b
commit b812c1a17a
3 changed files with 57 additions and 22 deletions

View File

@ -459,27 +459,27 @@ class Parser:
logger.debug("ops info = %r", ops_info)
return ops_info
def analyze_super(self, father_class_node, subclass_instance):
def analyze_super(self, class_type_node, subclass_instance):
"""Analyze super and return a class instance."""
father_class = None
if father_class_node is None:
father_class = type(subclass_instance)
if isinstance(father_class_node, ast.Name):
father_class_name = getattr(father_class_node, 'id')
father_class = self.global_namespace[father_class_name]
if isinstance(father_class_node, ast.Attribute):
value = getattr(father_class_node, 'value')
attr = getattr(father_class_node, 'attr')
module_name = getattr(value, 'id')
father_class_module = self.global_namespace[module_name]
father_class = getattr(father_class_module, attr)
if father_class is None:
raise ValueError("When call 'super', the father class is None.")
if not isinstance(subclass_instance, father_class):
raise ValueError("When call 'super', the second arg should be an instance of first arg.")
sub_class = type(subclass_instance)
if class_type_node is None:
return super(sub_class, subclass_instance)
if isinstance(class_type_node, ast.Name):
class_name = getattr(class_type_node, 'id')
elif isinstance(class_type_node, ast.Attribute):
class_name = getattr(class_type_node, 'attr')
else:
raise ValueError(f"When call 'super', the first arg should be a class type, "
f"but got {class_type_node.__class__.__name__}.")
target_class_instance = super(father_class, subclass_instance)
return target_class_instance
target_father_class = None
for class_element in sub_class.mro():
if class_element.__name__ == class_name:
target_father_class = class_element
break
if target_father_class is None:
raise ValueError("When call 'super', the second arg should be an instance of first arg.")
return super(target_father_class, subclass_instance)
def get_location(self, node):
"""

View File

@ -58,6 +58,7 @@ class Cell:
>>> def construct(self, x):
>>> return self.relu(x)
"""
def __init__(self, auto_prefix=True, flags=None):
self._params = OrderedDict()
self._cells = OrderedDict()
@ -888,6 +889,7 @@ class Cell:
for param in params:
param.set_param_ps(init_in_server)
class GraphKernel(Cell):
"""
Base class for GraphKernel.
@ -904,6 +906,7 @@ class GraphKernel(Cell):
>>> def construct(self, x):
>>> return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
"""
def __init__(self, auto_prefix=True, pips=None):
super(GraphKernel, self).__init__(auto_prefix, pips)
class_name = self.__class__.__name__

View File

@ -92,7 +92,7 @@ class Net(nn.Cell):
def test_single_super():
single_net = SingleSubNet(2, 3)
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.ones([1, 2, 3], np.int32))
y = Tensor(np.ones([1, 2, 3], np.int32))
single_net(x, y)
@ -100,7 +100,7 @@ def test_single_super():
def test_mul_super():
mul_net = MulSubNet(2, 3, 4)
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.ones([1, 2, 3], np.int32))
y = Tensor(np.ones([1, 2, 3], np.int32))
mul_net(x, y)
@ -108,9 +108,41 @@ def test_mul_super():
def test_super_cell():
net = Net(2)
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.ones([1, 2, 3], np.int32))
y = Tensor(np.ones([1, 2, 3], np.int32))
with pytest.raises(RuntimeError) as er:
net(x, y)
assert "Unsupported syntax 'Raise'" in str(er.value)
def test_single_super_in():
class FatherNetIn(nn.Cell):
def __init__(self, x):
super(FatherNetIn, self).__init__(x)
self.x = x
def construct(self, x, y):
return self.x * x
def test_father(self, x):
return self.x + x
class SingleSubNetIN(FatherNetIn):
def __init__(self, x, z):
super(SingleSubNetIN, self).__init__(x)
self.z = z
def construct(self, x, y):
ret_father_construct = super().construct(x, y)
ret_father_test = super(SingleSubNetIN, self).test_father(x)
ret_father_x = super(SingleSubNetIN, self).x
ret_sub_z = self.z
return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z
single_net_in = SingleSubNetIN(2, 3)
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
x = Tensor(np.ones([1, 2, 3], np.int32))
y = Tensor(np.ones([1, 2, 3], np.int32))
single_net_in(x, y)