diff --git a/docs/api/api_python/nn/mindspore.nn.Cell.rst b/docs/api/api_python/nn/mindspore.nn.Cell.rst index 7f2bf7f0526..2ccbdec622b 100644 --- a/docs/api/api_python/nn/mindspore.nn.Cell.rst +++ b/docs/api/api_python/nn/mindspore.nn.Cell.rst @@ -27,6 +27,16 @@ 参数: - **flags** (dict) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值:None。 + .. py:method:: apply(fn) + + 递归地将 `fn` 应用于每个子Cell(由 `.cells()` 返回)以及自身。通常用于初始化模型的参数。 + + 参数: + - **fn** (function) - 被执行于每个Cell的function。 + + 返回: + Cell类型,Cell本身。 + .. py:method:: auto_cast_inputs(inputs) 在混合精度下,自动对输入进行类型转换。 diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py index afe22c4b396..11d37cce92b 100755 --- a/mindspore/python/mindspore/nn/cell.py +++ b/mindspore/python/mindspore/nn/cell.py @@ -1414,6 +1414,38 @@ class Cell(Cell_): if "fp32" in flags and flags.get("fp32", False): self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32) + def apply(self, fn): + """ + Applies fn recursively to every subcell (as returned by .cells()) as well as self. + Typical use includes initializing the parameters of a model. + + Args: + fn (function) – function to be applied to each subcell + + Returns: + Cell, self. + + Examples: + >>> import mindspore.nn as nn + >>> from mindspore.common.initializer import initializer, One + >>> net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2)) + >>> def func(cell): + ... if isinstance(cell, nn.Dense): + ... cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype)) + >>> net.apply(func) + SequentialCell< + (0): Dense + (1): Dense + > + >>> print(net[0].weight.asnumpy()) + [[1. 1.] + [1. 1.]] + """ + for cell in self.cells(): + cell.apply(fn) + fn(self) + return self + def add_flags(self, **flags): """ Add customized attributes for cell. diff --git a/tests/ut/python/nn/test_cell.py b/tests/ut/python/nn/test_cell.py index 086fd077e69..244ed59a7f8 100644 --- a/tests/ut/python/nn/test_cell.py +++ b/tests/ut/python/nn/test_cell.py @@ -22,6 +22,7 @@ import mindspore.nn as nn from mindspore import Tensor, Parameter from mindspore.ops import operations as P from mindspore.common.api import _cell_graph_executor +from mindspore.common.initializer import initializer, One class ModA(nn.Cell): @@ -266,6 +267,25 @@ def test_add_attr(): ModAddCellError(ta) +def test_apply(): + """ + Feature: Cell.apply. + Description: Verify Cell.apply. + Expectation: No exception. + """ + net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2)) + + def func(cell): + if isinstance(cell, nn.Dense): + cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype)) + + net.apply(func) + + target = np.ones((2, 2), ms.dtype_to_nptype(net[0].weight.dtype)) + assert np.allclose(target, net[0].weight.asnumpy()) + assert np.allclose(target, net[1].weight.asnumpy()) + + def test_train_eval(): m = nn.Cell() assert not m.training @@ -305,9 +325,6 @@ def test_cell_names(): class TestKwargsNet(nn.Cell): - def __init__(self): - super(TestKwargsNet, self).__init__() - def construct(self, p1, p2, p3=False, p4=False): if p3: return p1 @@ -315,6 +332,7 @@ class TestKwargsNet(nn.Cell): return P.Add()(p1, p2) return p2 + def test_kwargs_default_value1(): """ Feature: Supports Cell kwargs inputs. @@ -334,7 +352,6 @@ def test_kwargs_default_value2(): Description: Pass kwargs. Expectation: No exception. """ - # Tensor(np.array([1, 2, 3, 4]), ms.float32).reshape((1, 1, 2, 2)) x = Tensor([[[[1.0, 2.0], [3.0, 4.0]]]], ms.float32) nn_op = nn.ResizeBilinear() res = nn_op(x, (4, 4), align_corners=True) diff --git a/tests/ut/python/pynative_mode/nn/test_cell.py b/tests/ut/python/pynative_mode/nn/test_cell.py index 9297670855c..3a238641988 100644 --- a/tests/ut/python/pynative_mode/nn/test_cell.py +++ b/tests/ut/python/pynative_mode/nn/test_cell.py @@ -16,8 +16,10 @@ import numpy as np import pytest +import mindspore as ms import mindspore.nn as nn from mindspore import Tensor, Parameter +from mindspore.common.initializer import initializer, One from ...ut_filter import non_graph_engine @@ -276,6 +278,25 @@ def test_add_attr(): ModAddCellError(ta) +def test_apply(): + """ + Feature: Cell.apply. + Description: Verify Cell.apply. + Expectation: No exception. + """ + net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2)) + + def func(cell): + if isinstance(cell, nn.Dense): + cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype)) + + net.apply(func) + + target = np.ones((2, 2), ms.dtype_to_nptype(net[0].weight.dtype)) + assert np.allclose(target, net[0].weight.asnumpy()) + assert np.allclose(target, net[1].weight.asnumpy()) + + def test_train_eval(): """ test_train_eval """ m = nn.Cell()