forked from mindspore-Ecosystem/mindspore
!3371 support call super when class define in test_case
Merge pull request !3371 from zhangbuxue/support_call_super_when_class_define_in_test_case
This commit is contained in:
commit
366dd44ca8
|
@ -459,27 +459,27 @@ class Parser:
|
||||||
logger.debug("ops info = %r", ops_info)
|
logger.debug("ops info = %r", ops_info)
|
||||||
return ops_info
|
return ops_info
|
||||||
|
|
||||||
def analyze_super(self, father_class_node, subclass_instance):
|
def analyze_super(self, class_type_node, subclass_instance):
|
||||||
"""Analyze super and return a class instance."""
|
"""Analyze super and return a class instance."""
|
||||||
father_class = None
|
sub_class = type(subclass_instance)
|
||||||
if father_class_node is None:
|
if class_type_node is None:
|
||||||
father_class = type(subclass_instance)
|
return super(sub_class, subclass_instance)
|
||||||
if isinstance(father_class_node, ast.Name):
|
if isinstance(class_type_node, ast.Name):
|
||||||
father_class_name = getattr(father_class_node, 'id')
|
class_name = getattr(class_type_node, 'id')
|
||||||
father_class = self.global_namespace[father_class_name]
|
elif isinstance(class_type_node, ast.Attribute):
|
||||||
if isinstance(father_class_node, ast.Attribute):
|
class_name = getattr(class_type_node, 'attr')
|
||||||
value = getattr(father_class_node, 'value')
|
else:
|
||||||
attr = getattr(father_class_node, 'attr')
|
raise ValueError(f"When call 'super', the first arg should be a class type, "
|
||||||
module_name = getattr(value, 'id')
|
f"but got {class_type_node.__class__.__name__}.")
|
||||||
father_class_module = self.global_namespace[module_name]
|
|
||||||
father_class = getattr(father_class_module, attr)
|
|
||||||
if father_class is None:
|
|
||||||
raise ValueError("When call 'super', the father class is None.")
|
|
||||||
if not isinstance(subclass_instance, father_class):
|
|
||||||
raise ValueError("When call 'super', the second arg should be an instance of first arg.")
|
|
||||||
|
|
||||||
target_class_instance = super(father_class, subclass_instance)
|
target_father_class = None
|
||||||
return target_class_instance
|
for class_element in sub_class.mro():
|
||||||
|
if class_element.__name__ == class_name:
|
||||||
|
target_father_class = class_element
|
||||||
|
break
|
||||||
|
if target_father_class is None:
|
||||||
|
raise ValueError("When call 'super', the second arg should be an instance of first arg.")
|
||||||
|
return super(target_father_class, subclass_instance)
|
||||||
|
|
||||||
def get_location(self, node):
|
def get_location(self, node):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -58,6 +58,7 @@ class Cell:
|
||||||
>>> def construct(self, x):
|
>>> def construct(self, x):
|
||||||
>>> return self.relu(x)
|
>>> return self.relu(x)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, auto_prefix=True, flags=None):
|
def __init__(self, auto_prefix=True, flags=None):
|
||||||
self._params = OrderedDict()
|
self._params = OrderedDict()
|
||||||
self._cells = OrderedDict()
|
self._cells = OrderedDict()
|
||||||
|
@ -888,6 +889,7 @@ class Cell:
|
||||||
for param in params:
|
for param in params:
|
||||||
param.set_param_ps(init_in_server)
|
param.set_param_ps(init_in_server)
|
||||||
|
|
||||||
|
|
||||||
class GraphKernel(Cell):
|
class GraphKernel(Cell):
|
||||||
"""
|
"""
|
||||||
Base class for GraphKernel.
|
Base class for GraphKernel.
|
||||||
|
@ -904,6 +906,7 @@ class GraphKernel(Cell):
|
||||||
>>> def construct(self, x):
|
>>> def construct(self, x):
|
||||||
>>> return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
|
>>> return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, auto_prefix=True, pips=None):
|
def __init__(self, auto_prefix=True, pips=None):
|
||||||
super(GraphKernel, self).__init__(auto_prefix, pips)
|
super(GraphKernel, self).__init__(auto_prefix, pips)
|
||||||
class_name = self.__class__.__name__
|
class_name = self.__class__.__name__
|
||||||
|
|
|
@ -92,7 +92,7 @@ class Net(nn.Cell):
|
||||||
|
|
||||||
def test_single_super():
|
def test_single_super():
|
||||||
single_net = SingleSubNet(2, 3)
|
single_net = SingleSubNet(2, 3)
|
||||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||||
single_net(x, y)
|
single_net(x, y)
|
||||||
|
@ -100,7 +100,7 @@ def test_single_super():
|
||||||
|
|
||||||
def test_mul_super():
|
def test_mul_super():
|
||||||
mul_net = MulSubNet(2, 3, 4)
|
mul_net = MulSubNet(2, 3, 4)
|
||||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||||
mul_net(x, y)
|
mul_net(x, y)
|
||||||
|
@ -108,9 +108,41 @@ def test_mul_super():
|
||||||
|
|
||||||
def test_super_cell():
|
def test_super_cell():
|
||||||
net = Net(2)
|
net = Net(2)
|
||||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||||
with pytest.raises(RuntimeError) as er:
|
with pytest.raises(RuntimeError) as er:
|
||||||
net(x, y)
|
net(x, y)
|
||||||
assert "Unsupported syntax 'Raise'" in str(er.value)
|
assert "Unsupported syntax 'Raise'" in str(er.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_super_in():
|
||||||
|
class FatherNetIn(nn.Cell):
|
||||||
|
def __init__(self, x):
|
||||||
|
super(FatherNetIn, self).__init__(x)
|
||||||
|
self.x = x
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.x * x
|
||||||
|
|
||||||
|
def test_father(self, x):
|
||||||
|
return self.x + x
|
||||||
|
|
||||||
|
class SingleSubNetIN(FatherNetIn):
|
||||||
|
def __init__(self, x, z):
|
||||||
|
super(SingleSubNetIN, self).__init__(x)
|
||||||
|
self.z = z
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
ret_father_construct = super().construct(x, y)
|
||||||
|
ret_father_test = super(SingleSubNetIN, self).test_father(x)
|
||||||
|
ret_father_x = super(SingleSubNetIN, self).x
|
||||||
|
ret_sub_z = self.z
|
||||||
|
|
||||||
|
return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z
|
||||||
|
|
||||||
|
single_net_in = SingleSubNetIN(2, 3)
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||||
|
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||||
|
single_net_in(x, y)
|
||||||
|
|
Loading…
Reference in New Issue