diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py index a69d62869c2..180edf0fb4c 100644 --- a/mindspore/_extends/parse/parser.py +++ b/mindspore/_extends/parse/parser.py @@ -459,27 +459,27 @@ class Parser: logger.debug("ops info = %r", ops_info) return ops_info - def analyze_super(self, father_class_node, subclass_instance): + def analyze_super(self, class_type_node, subclass_instance): """Analyze super and return a class instance.""" - father_class = None - if father_class_node is None: - father_class = type(subclass_instance) - if isinstance(father_class_node, ast.Name): - father_class_name = getattr(father_class_node, 'id') - father_class = self.global_namespace[father_class_name] - if isinstance(father_class_node, ast.Attribute): - value = getattr(father_class_node, 'value') - attr = getattr(father_class_node, 'attr') - module_name = getattr(value, 'id') - father_class_module = self.global_namespace[module_name] - father_class = getattr(father_class_module, attr) - if father_class is None: - raise ValueError("When call 'super', the father class is None.") - if not isinstance(subclass_instance, father_class): - raise ValueError("When call 'super', the second arg should be an instance of first arg.") + sub_class = type(subclass_instance) + if class_type_node is None: + return super(sub_class, subclass_instance) + if isinstance(class_type_node, ast.Name): + class_name = getattr(class_type_node, 'id') + elif isinstance(class_type_node, ast.Attribute): + class_name = getattr(class_type_node, 'attr') + else: + raise ValueError(f"When call 'super', the first arg should be a class type, " + f"but got {class_type_node.__class__.__name__}.") - target_class_instance = super(father_class, subclass_instance) - return target_class_instance + target_father_class = None + for class_element in sub_class.mro(): + if class_element.__name__ == class_name: + target_father_class = class_element + break + if target_father_class is None: + raise ValueError("When call 'super', the second arg should be an instance of first arg.") + return super(target_father_class, subclass_instance) def get_location(self, node): """ diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index d7e18c67fec..723fa213005 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -58,6 +58,7 @@ class Cell: >>> def construct(self, x): >>> return self.relu(x) """ + def __init__(self, auto_prefix=True, flags=None): self._params = OrderedDict() self._cells = OrderedDict() @@ -888,6 +889,7 @@ class Cell: for param in params: param.set_param_ps(init_in_server) + class GraphKernel(Cell): """ Base class for GraphKernel. @@ -904,6 +906,7 @@ class GraphKernel(Cell): >>> def construct(self, x): >>> return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x) """ + def __init__(self, auto_prefix=True, pips=None): super(GraphKernel, self).__init__(auto_prefix, pips) class_name = self.__class__.__name__ diff --git a/tests/ut/python/pipeline/parse/test_super.py b/tests/ut/python/pipeline/parse/test_super.py index eb3bf1682dc..6405b278ae1 100644 --- a/tests/ut/python/pipeline/parse/test_super.py +++ b/tests/ut/python/pipeline/parse/test_super.py @@ -92,7 +92,7 @@ class Net(nn.Cell): def test_single_super(): single_net = SingleSubNet(2, 3) - context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + context.set_context(mode=context.GRAPH_MODE) x = Tensor(np.ones([1, 2, 3], np.int32)) y = Tensor(np.ones([1, 2, 3], np.int32)) single_net(x, y) @@ -100,7 +100,7 @@ def test_single_super(): def test_mul_super(): mul_net = MulSubNet(2, 3, 4) - context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + context.set_context(mode=context.GRAPH_MODE) x = Tensor(np.ones([1, 2, 3], np.int32)) y = Tensor(np.ones([1, 2, 3], np.int32)) mul_net(x, y) @@ -108,9 +108,41 @@ def test_mul_super(): def test_super_cell(): net = Net(2) - context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + context.set_context(mode=context.GRAPH_MODE) x = Tensor(np.ones([1, 2, 3], np.int32)) y = Tensor(np.ones([1, 2, 3], np.int32)) with pytest.raises(RuntimeError) as er: net(x, y) assert "Unsupported syntax 'Raise'" in str(er.value) + + +def test_single_super_in(): + class FatherNetIn(nn.Cell): + def __init__(self, x): + super(FatherNetIn, self).__init__(x) + self.x = x + + def construct(self, x, y): + return self.x * x + + def test_father(self, x): + return self.x + x + + class SingleSubNetIN(FatherNetIn): + def __init__(self, x, z): + super(SingleSubNetIN, self).__init__(x) + self.z = z + + def construct(self, x, y): + ret_father_construct = super().construct(x, y) + ret_father_test = super(SingleSubNetIN, self).test_father(x) + ret_father_x = super(SingleSubNetIN, self).x + ret_sub_z = self.z + + return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z + + single_net_in = SingleSubNetIN(2, 3) + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + x = Tensor(np.ones([1, 2, 3], np.int32)) + y = Tensor(np.ones([1, 2, 3], np.int32)) + single_net_in(x, y)