add cell.apply

This commit is contained in:
fandawei 2022-12-26 16:01:05 +08:00
parent dcb6f6e5c8
commit d871738bb3
4 changed files with 84 additions and 4 deletions

View File

@ -27,6 +27,16 @@
参数:
- **flags** (dict) - Cell的配置信息目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值None。
.. py:method:: apply(fn)
递归地将 `fn` 应用于每个子Cell`.cells()` 返回)以及自身。通常用于初始化模型的参数。
参数:
- **fn** (function) - 被执行于每个Cell的function。
返回:
Cell类型Cell本身。
.. py:method:: auto_cast_inputs(inputs)
在混合精度下,自动对输入进行类型转换。

View File

@ -1414,6 +1414,38 @@ class Cell(Cell_):
if "fp32" in flags and flags.get("fp32", False):
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
def apply(self, fn):
"""
Applies fn recursively to every subcell (as returned by .cells()) as well as self.
Typical use includes initializing the parameters of a model.
Args:
fn (function) function to be applied to each subcell
Returns:
Cell, self.
Examples:
>>> import mindspore.nn as nn
>>> from mindspore.common.initializer import initializer, One
>>> net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
>>> def func(cell):
... if isinstance(cell, nn.Dense):
... cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
>>> net.apply(func)
SequentialCell<
(0): Dense<input_channels=2, output_channels=2, has_bias=True>
(1): Dense<input_channels=2, output_channels=2, has_bias=True>
>
>>> print(net[0].weight.asnumpy())
[[1. 1.]
[1. 1.]]
"""
for cell in self.cells():
cell.apply(fn)
fn(self)
return self
def add_flags(self, **flags):
"""
Add customized attributes for cell.

View File

@ -22,6 +22,7 @@ import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.common.api import _cell_graph_executor
from mindspore.common.initializer import initializer, One
class ModA(nn.Cell):
@ -266,6 +267,25 @@ def test_add_attr():
ModAddCellError(ta)
def test_apply():
"""
Feature: Cell.apply.
Description: Verify Cell.apply.
Expectation: No exception.
"""
net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
def func(cell):
if isinstance(cell, nn.Dense):
cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
net.apply(func)
target = np.ones((2, 2), ms.dtype_to_nptype(net[0].weight.dtype))
assert np.allclose(target, net[0].weight.asnumpy())
assert np.allclose(target, net[1].weight.asnumpy())
def test_train_eval():
m = nn.Cell()
assert not m.training
@ -305,9 +325,6 @@ def test_cell_names():
class TestKwargsNet(nn.Cell):
def __init__(self):
super(TestKwargsNet, self).__init__()
def construct(self, p1, p2, p3=False, p4=False):
if p3:
return p1
@ -315,6 +332,7 @@ class TestKwargsNet(nn.Cell):
return P.Add()(p1, p2)
return p2
def test_kwargs_default_value1():
"""
Feature: Supports Cell kwargs inputs.
@ -334,7 +352,6 @@ def test_kwargs_default_value2():
Description: Pass kwargs.
Expectation: No exception.
"""
# Tensor(np.array([1, 2, 3, 4]), ms.float32).reshape((1, 1, 2, 2))
x = Tensor([[[[1.0, 2.0], [3.0, 4.0]]]], ms.float32)
nn_op = nn.ResizeBilinear()
res = nn_op(x, (4, 4), align_corners=True)

View File

@ -16,8 +16,10 @@
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer, One
from ...ut_filter import non_graph_engine
@ -276,6 +278,25 @@ def test_add_attr():
ModAddCellError(ta)
def test_apply():
"""
Feature: Cell.apply.
Description: Verify Cell.apply.
Expectation: No exception.
"""
net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
def func(cell):
if isinstance(cell, nn.Dense):
cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
net.apply(func)
target = np.ones((2, 2), ms.dtype_to_nptype(net[0].weight.dtype))
assert np.allclose(target, net[0].weight.asnumpy())
assert np.allclose(target, net[1].weight.asnumpy())
def test_train_eval():
""" test_train_eval """
m = nn.Cell()