forked from OSSInnovation/mindspore
!3371 support call super when class define in test_case
Merge pull request !3371 from zhangbuxue/support_call_super_when_class_define_in_test_case
This commit is contained in:
commit
366dd44ca8
|
@ -459,27 +459,27 @@ class Parser:
|
|||
logger.debug("ops info = %r", ops_info)
|
||||
return ops_info
|
||||
|
||||
def analyze_super(self, father_class_node, subclass_instance):
|
||||
def analyze_super(self, class_type_node, subclass_instance):
|
||||
"""Analyze super and return a class instance."""
|
||||
father_class = None
|
||||
if father_class_node is None:
|
||||
father_class = type(subclass_instance)
|
||||
if isinstance(father_class_node, ast.Name):
|
||||
father_class_name = getattr(father_class_node, 'id')
|
||||
father_class = self.global_namespace[father_class_name]
|
||||
if isinstance(father_class_node, ast.Attribute):
|
||||
value = getattr(father_class_node, 'value')
|
||||
attr = getattr(father_class_node, 'attr')
|
||||
module_name = getattr(value, 'id')
|
||||
father_class_module = self.global_namespace[module_name]
|
||||
father_class = getattr(father_class_module, attr)
|
||||
if father_class is None:
|
||||
raise ValueError("When call 'super', the father class is None.")
|
||||
if not isinstance(subclass_instance, father_class):
|
||||
raise ValueError("When call 'super', the second arg should be an instance of first arg.")
|
||||
sub_class = type(subclass_instance)
|
||||
if class_type_node is None:
|
||||
return super(sub_class, subclass_instance)
|
||||
if isinstance(class_type_node, ast.Name):
|
||||
class_name = getattr(class_type_node, 'id')
|
||||
elif isinstance(class_type_node, ast.Attribute):
|
||||
class_name = getattr(class_type_node, 'attr')
|
||||
else:
|
||||
raise ValueError(f"When call 'super', the first arg should be a class type, "
|
||||
f"but got {class_type_node.__class__.__name__}.")
|
||||
|
||||
target_class_instance = super(father_class, subclass_instance)
|
||||
return target_class_instance
|
||||
target_father_class = None
|
||||
for class_element in sub_class.mro():
|
||||
if class_element.__name__ == class_name:
|
||||
target_father_class = class_element
|
||||
break
|
||||
if target_father_class is None:
|
||||
raise ValueError("When call 'super', the second arg should be an instance of first arg.")
|
||||
return super(target_father_class, subclass_instance)
|
||||
|
||||
def get_location(self, node):
|
||||
"""
|
||||
|
|
|
@ -58,6 +58,7 @@ class Cell:
|
|||
>>> def construct(self, x):
|
||||
>>> return self.relu(x)
|
||||
"""
|
||||
|
||||
def __init__(self, auto_prefix=True, flags=None):
|
||||
self._params = OrderedDict()
|
||||
self._cells = OrderedDict()
|
||||
|
@ -888,6 +889,7 @@ class Cell:
|
|||
for param in params:
|
||||
param.set_param_ps(init_in_server)
|
||||
|
||||
|
||||
class GraphKernel(Cell):
|
||||
"""
|
||||
Base class for GraphKernel.
|
||||
|
@ -904,6 +906,7 @@ class GraphKernel(Cell):
|
|||
>>> def construct(self, x):
|
||||
>>> return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
|
||||
"""
|
||||
|
||||
def __init__(self, auto_prefix=True, pips=None):
|
||||
super(GraphKernel, self).__init__(auto_prefix, pips)
|
||||
class_name = self.__class__.__name__
|
||||
|
|
|
@ -92,7 +92,7 @@ class Net(nn.Cell):
|
|||
|
||||
def test_single_super():
|
||||
single_net = SingleSubNet(2, 3)
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
single_net(x, y)
|
||||
|
@ -100,7 +100,7 @@ def test_single_super():
|
|||
|
||||
def test_mul_super():
|
||||
mul_net = MulSubNet(2, 3, 4)
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
mul_net(x, y)
|
||||
|
@ -108,9 +108,41 @@ def test_mul_super():
|
|||
|
||||
def test_super_cell():
|
||||
net = Net(2)
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
with pytest.raises(RuntimeError) as er:
|
||||
net(x, y)
|
||||
assert "Unsupported syntax 'Raise'" in str(er.value)
|
||||
|
||||
|
||||
def test_single_super_in():
|
||||
class FatherNetIn(nn.Cell):
|
||||
def __init__(self, x):
|
||||
super(FatherNetIn, self).__init__(x)
|
||||
self.x = x
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.x * x
|
||||
|
||||
def test_father(self, x):
|
||||
return self.x + x
|
||||
|
||||
class SingleSubNetIN(FatherNetIn):
|
||||
def __init__(self, x, z):
|
||||
super(SingleSubNetIN, self).__init__(x)
|
||||
self.z = z
|
||||
|
||||
def construct(self, x, y):
|
||||
ret_father_construct = super().construct(x, y)
|
||||
ret_father_test = super(SingleSubNetIN, self).test_father(x)
|
||||
ret_father_x = super(SingleSubNetIN, self).x
|
||||
ret_sub_z = self.z
|
||||
|
||||
return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z
|
||||
|
||||
single_net_in = SingleSubNetIN(2, 3)
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
single_net_in(x, y)
|
||||
|
|
Loading…
Reference in New Issue